Skip to content

Commit eaf9da7

Browse files
Changes after rebase ontop of fixes for mixed parameter fixes
1 parent ba9e209 commit eaf9da7

File tree

4 files changed

+29
-28
lines changed

4 files changed

+29
-28
lines changed

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
153153
}
154154
}
155155

156-
//Save/load migraphx compiled models
156+
// Save/load migraphx compiled models
157157
const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel);
158158
if (!save_comp_model_env.empty()) {
159159
save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true);
@@ -1151,8 +1151,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
11511151
migraphx::program prog;
11521152

11531153
if (!no_input_shape) {
1154-
if(!load_compiled_model_)
1155-
{
1154+
if (!load_compiled_model_) {
11561155
LOGS_DEFAULT(INFO) << "No Input shapes detected quantizing model" << std::endl;
11571156
prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);
11581157

@@ -1192,18 +1191,17 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
11921191
prog.compile(t_, co);
11931192
LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl;
11941193

1195-
if (save_compiled_model_)
1196-
{
1194+
if (save_compiled_model_) {
11971195
LOGS_DEFAULT(INFO) << "Model Save: Begin" << std::endl;
11981196
migraphx::file_options fo;
11991197
fo.set_file_format("msgpack");
12001198
migraphx::save(prog, save_compiled_path_.c_str(), fo);
12011199
LOGS_DEFAULT(INFO) << "Model Save: Complete" << std::endl;
12021200
}
1203-
}
1204-
else
1205-
{
1201+
} else {
1202+
LOGS_DEFAULT(INFO) << "Model Load: Attempting to Load Pre-Compiled Model" << std::endl;
12061203
prog = migraphx::load(load_compiled_path_.c_str());
1204+
LOGS_DEFAULT(INFO) << "Model Load: Complete" << std::endl;
12071205
}
12081206

12091207
auto prog_output_shapes = prog.get_output_shapes();
@@ -1303,8 +1301,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
13031301
// input shapes are different, needs to re-parse onnx and
13041302
// re-compile the program
13051303
if (!input_shape_match) {
1306-
if(!load_compiled_model_)
1307-
{
1304+
if (!load_compiled_model_) {
13081305
LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl;
13091306
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
13101307

@@ -1363,16 +1360,14 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
13631360
prog.compile(t, co);
13641361

13651362
LOGS_DEFAULT(INFO) << "Model Compile: Completed" << std::endl;
1366-
if (save_compiled_model_)
1367-
{
1363+
if (save_compiled_model_) {
13681364
LOGS_DEFAULT(INFO) << "Model Save: Begin" << std::endl;
13691365
migraphx::file_options fo;
13701366
fo.set_file_format("msgpack");
13711367
migraphx::save(prog, save_compiled_path_.c_str(), fo);
13721368
LOGS_DEFAULT(INFO) << "Model Save: Completed" << std::endl;
13731369
}
1374-
else
1375-
{
1370+
} else {
13761371
prog = migraphx::load(load_compiled_path_.c_str());
13771372
}
13781373

@@ -1462,7 +1457,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
14621457
HIP_CALL_THROW(hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice));
14631458
}
14641459
}
1465-
}
1460+
};
14661461

14671462
return Status::OK();
14681463
};

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ constexpr const char* kSaveModelPath = "migx_save_model_name";
2222
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
2323
constexpr const char* kLoadModelPath = "migx_load_model_name";
2424

25-
2625
} // namespace provider_option_names
2726
} // namespace migraphx
2827

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,11 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
832832
0,
833833
0,
834834
0,
835-
nullptr};
835+
nullptr,
836+
1,
837+
"./compiled_model.mxr",
838+
1,
839+
"./compiled_model.mxr"};
836840
for (auto option : it->second) {
837841
if (option.first == "device_id") {
838842
if (!option.second.empty()) {
@@ -879,7 +883,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
879883
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
880884
" 'True' or 'False'. Default value is 'False'.\n");
881885
}
882-
} else if(option.first == "migraphx_save_compiled_model") {
886+
} else if (option.first == "migraphx_save_compiled_model") {
883887
if (option.second == "True" || option.second == "true") {
884888
params.migraphx_fp16_enable = true;
885889
} else if (option.second == "False" || option.second == "false") {
@@ -889,16 +893,16 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
889893
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
890894
" 'True' or 'False'. Default value is 'False'.\n");
891895
}
892-
} else if(option.first == "migraphx_save_model_path") {
896+
} else if (option.first == "migraphx_save_model_path") {
893897
if (!option.second.empty()) {
894898
save_model_path = option.second;
895-
params.migraphx_save_compiled_model_path = save_model_path.c_str();
899+
params.migraphx_save_model_path = save_model_path.c_str();
896900
} else {
897901
ORT_THROW(
898902
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
899-
"file name i.e. 'model.mxr'.\n");
903+
"file name i.e. 'compiled_model.mxr'.\n");
900904
}
901-
} else if(option.first == "migraphx_load_compiled_model") {
905+
} else if (option.first == "migraphx_load_compiled_model") {
902906
if (option.second == "True" || option.second == "true") {
903907
params.migraphx_fp16_enable = true;
904908
} else if (option.second == "False" || option.second == "false") {
@@ -908,17 +912,16 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
908912
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
909913
" 'True' or 'False'. Default value is 'False'.\n");
910914
}
911-
} else if(option.first == "migraphx_load_model_path") {
915+
} else if (option.first == "migraphx_load_model_path") {
912916
if (!option.second.empty()) {
913917
load_model_path = option.second;
914-
params.migraphx_load_compiled_model_path = load_model_path.c_str();
918+
params.migraphx_load_model_path = load_model_path.c_str();
915919
} else {
916920
ORT_THROW(
917921
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
918-
"file name i.e. 'model.mxr'.\n");
922+
"file name i.e. 'compiled_model.mxr'.\n");
919923
}
920-
}
921-
else {
924+
} else {
922925
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
923926
}
924927
}

onnxruntime/test/util/default_providers.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
7676
0,
7777
0,
7878
0,
79-
nullptr};
79+
nullptr,
80+
1,
81+
"./compiled_model.mxr",
82+
1,
83+
"./compiled_model.mxr"};
8084
return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider();
8185
#else
8286
return nullptr;

0 commit comments

Comments
 (0)