Skip to content

Commit 6502957

Browse files
Ted Themistokleous
authored and committed
Use save/load in the MIGraphX EP
1 parent fd7f121 commit 6502957

File tree

1 file changed

+72
-40
lines changed

1 file changed

+72
-40
lines changed

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 72 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1151,32 +1151,48 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
11511151
migraphx::program prog;
11521152

11531153
if (!no_input_shape) {
1154-
prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);
1155-
if (fp16_enable_) {
1156-
migraphx::quantize_fp16(prog);
1157-
}
11581154

1159-
// Read in the calibration data and map it to a migraphx parameter map for the calibration ops
1160-
if (int8_enable_ && int8_calibration_cache_available_) {
1161-
migraphx::quantize_int8_options quant_opts;
1162-
migraphx::program_parameters quant_params;
11631155

1164-
auto param_shapes = prog.get_parameter_shapes();
1156+
if(!load_compiled_model)
1157+
{
1158+
prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);
1159+
if (fp16_enable_) {
1160+
migraphx::quantize_fp16(prog);
1161+
}
11651162

1166-
for (auto&& name : param_shapes.names()) {
1167-
auto dynamic_range_i = dynamic_range_map.find(name);
1168-
if (dynamic_range_i != dynamic_range_map.end()) {
1169-
quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
1163+
// Read in the calibration data and map it to a migraphx parameter map for the calibration ops
1164+
if (int8_enable_ && int8_calibration_cache_available_) {
1165+
migraphx::quantize_int8_options quant_opts;
1166+
migraphx::program_parameters quant_params;
1167+
1168+
auto param_shapes = prog.get_parameter_shapes();
1169+
1170+
for (auto&& name : param_shapes.names()) {
1171+
auto dynamic_range_i = dynamic_range_map.find(name);
1172+
if (dynamic_range_i != dynamic_range_map.end()) {
1173+
quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
1174+
}
11701175
}
1171-
}
11721176

1173-
quant_opts.add_calibration_data(quant_params);
1174-
// perform static quantization on the programs
1175-
migraphx::quantize_int8(prog, t_, quant_opts);
1177+
quant_opts.add_calibration_data(quant_params);
1178+
// perform static quantization on the programs
1179+
migraphx::quantize_int8(prog, t_, quant_opts);
1180+
}
1181+
migraphx::compile_options co;
1182+
co.set_fast_math(false);
1183+
prog.compile(t_, co);
1184+
if (save_compiled_mode)
1185+
{
1186+
migraphx::file_options fo;
1187+
fo.set_file_format("msgpack");
1188+
migraphx::save(prog, save_compiled_path, fo);
1189+
}
11761190
}
1177-
migraphx::compile_options co;
1178-
co.set_fast_math(false);
1179-
prog.compile(t_, co);
1191+
else
1192+
{
1193+
prog = migraphx::load(load_compiled_path);
1194+
}
1195+
11801196
auto prog_output_shapes = prog.get_output_shapes();
11811197
for (std::size_t i = 0; i < output_names.size(); ++i) {
11821198
auto out_len = prog_output_shapes[i].lengths();
@@ -1196,7 +1212,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
11961212
*p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
11971213
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
11981214
map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_,
1199-
int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_};
1215+
int8_calibration_cache_available_, dynamic_range_map,
1216+
save_compiled_mode_, save_compiled_path_,
1217+
load_compiled_mode_, load_compiled_path_, dump_model_ops_};
12001218
*state = p.release();
12011219
return 0;
12021220
};
@@ -1270,33 +1288,47 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
12701288
// input shapes are different, needs to re-parse onnx and
12711289
// re-compile the program
12721290
if (!input_shape_match) {
1273-
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
1274-
if (fp16_enable) {
1275-
migraphx::quantize_fp16(prog);
1276-
}
1291+
if(!load_compiled_model)
1292+
{
1293+
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
1294+
if (fp16_enable) {
1295+
migraphx::quantize_fp16(prog);
1296+
}
12771297

1278-
// Read in the calibration data and map it to a migraphx parameter map for the calibration ops
1279-
if (int8_enable && int8_calibration_cache_available) {
1280-
migraphx::quantize_int8_options quant_opts;
1281-
migraphx::program_parameters quant_params;
1298+
// Read in the calibration data and map it to a migraphx parameter map for the calibration ops
1299+
if (int8_enable && int8_calibration_cache_available) {
1300+
migraphx::quantize_int8_options quant_opts;
1301+
migraphx::program_parameters quant_params;
12821302

1283-
auto param_shapes = prog.get_parameter_shapes();
1303+
auto param_shapes = prog.get_parameter_shapes();
12841304

1285-
for (auto&& name : param_shapes.names()) {
1286-
auto dynamic_range_i = map_dynamic_range.find(name);
1287-
if (dynamic_range_i != map_dynamic_range.end()) {
1288-
quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
1305+
for (auto&& name : param_shapes.names()) {
1306+
auto dynamic_range_i = map_dynamic_range.find(name);
1307+
if (dynamic_range_i != map_dynamic_range.end()) {
1308+
quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
1309+
}
12891310
}
1311+
1312+
quant_opts.add_calibration_data(quant_params);
1313+
// perform static quantization on the programs
1314+
migraphx::quantize_int8(prog, t, quant_opts);
12901315
}
12911316

1292-
quant_opts.add_calibration_data(quant_params);
1293-
// perform static quantization on the programs
1294-
migraphx::quantize_int8(prog, t, quant_opts);
1317+
migraphx::compile_options co;
1318+
co.set_fast_math(false);
1319+
prog.compile(t, co);
1320+
if (save_compiled_mode)
1321+
{
1322+
migraphx::file_options fo;
1323+
fo.set_file_format("msgpack");
1324+
migraphx::save(prog, save_compiled_path, fo);
1325+
}
1326+
}
1327+
else
1328+
{
1329+
prog = migraphx::load(load_compiled_path);
12951330
}
12961331

1297-
migraphx::compile_options co;
1298-
co.set_fast_math(false);
1299-
prog.compile(t, co);
13001332
mgx_state->prog = prog;
13011333
param_shapes = prog.get_parameter_shapes();
13021334
no_input_shape = false;

0 commit comments

Comments
 (0)