Skip to content

Commit 800fedb

Browse files
author
Ted Themistokleous
committed
Add onnxruntime API hooks for save/load of MIGraphX models
1 parent d8d4bd6 commit 800fedb

File tree

5 files changed

+76
-1
lines changed

5 files changed

+76
-1
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,10 @@ typedef struct OrtMIGraphXProviderOptions {
614614
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
615615
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
616616
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
617+
int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true
618+
const char* migraphx_save_model_path; // migraphx model path name
619+
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
620+
const char* migraphx_load_model_path; // migraphx model path name
617621
} OrtMIGraphXProviderOptions;
618622

619623
/** \brief OpenVINO Provider Options

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ constexpr const char* kFp16Enable = "trt_fp16_enable";
1717
constexpr const char* kInt8Enable = "migx_int8_enable";
1818
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
1919
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
20+
constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
21+
constexpr const char* kSaveModelPath = "migx_save_model_name";
22+
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
23+
constexpr const char* kLoadModelPath = "migx_load_model_name";
24+
2025

2126
} // namespace provider_option_names
2227
} // namespace migraphx
@@ -39,6 +44,8 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
3944
})
4045
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
4146
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
47+
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
48+
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
4249
.Parse(options));
4350

4451
return info;
@@ -49,6 +56,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
4956
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
5057
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
5158
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
59+
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
60+
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
5261
};
5362
return options;
5463
}
@@ -58,6 +67,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
5867
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
5968
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
6069
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
70+
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
71+
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
6172
};
6273
return options;
6374
}

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ struct MIGraphXExecutionProviderInfo {
1919
bool int8_enable{false};
2020
std::string int8_calibration_table_name{""};
2121
bool int8_use_native_calibration_table{false};
22+
bool save_compiled_model{false};
23+
std::string save_model_file("");
24+
bool load_compiled_model{false};
25+
std::string load_model_file("");
2226

2327
static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
2428
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);

onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ struct MIGraphX_Provider : Provider {
5353
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name;
5454
}
5555
info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
56+
info.save_compiled_model = options.migraphx_save_compiled_model;
57+
info.save_model_file = "";
58+
if (options.migraphx_save_compiled_model_path != nullptr) {
59+
info.save_model_file = options.migraphx_save_compiled_model_path;
60+
}
61+
info.load_compiled_model = options.migraphx_load_compiled_model;
62+
info.load_model_file = "";
63+
if (options.migraphx_load_compiled_model_path != nullptr) {
64+
info.load_model_file = options.migraphx_load_compiled_model_path;
65+
}
5666
return std::make_shared<MIGraphXProviderFactory>(info);
5767
}
5868

@@ -79,6 +89,11 @@ struct MIGraphX_Provider : Provider {
7989
}
8090

8191
migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table;
92+
93+
migx_options.migraphx_save_compiled_model = internal_options.migraphx_save_compiled_model;
94+
migx_options.migraphx_save_model_path = internal_options.migraphx_save_compiled_model_path;
95+
migx_options.migraphx_load_compiled_model = internal_options.migraphx_load_compiled_model;
96+
migx_options.migraphx_load_model_path = internal_options.migraphx_load_compiled_model_path;
8297
}
8398

8499
ProviderOptions GetProviderOptions(const void* provider_options) override {

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,8 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
798798
} else if (type == kMIGraphXExecutionProvider) {
799799
#ifdef USE_MIGRAPHX
800800
std::string calibration_table;
801+
std::string save_model_path;
802+
std::string load_model_path;
801803
auto it = provider_options_map.find(type);
802804
if (it != provider_options_map.end()) {
803805
OrtMIGraphXProviderOptions params{
@@ -852,7 +854,46 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
852854
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
853855
" 'True' or 'False'. Default value is 'False'.\n");
854856
}
855-
} else {
857+
} else if(option.first == "migraphx_save_compiled_model") {
858+
if (option.second == "True" || option.second == "true") {
859+
params.migraphx_fp16_enable = true;
860+
} else if (option.second == "False" || option.second == "false") {
861+
params.migraphx_fp16_enable = false;
862+
} else {
863+
ORT_THROW(
864+
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
865+
" 'True' or 'False'. Default value is 'False'.\n");
866+
}
867+
} else if(option.first == "migraphx_save_model_path") {
868+
if (!option.second.empty()) {
869+
save_model_path = option.second;
870+
params.migraphx_save_compiled_model_path = save_model_path.c_str();
871+
} else {
872+
ORT_THROW(
873+
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
874+
"file name i.e. 'model.mxr'.\n");
875+
}
876+
} else if(option.first == "migraphx_load_compiled_model") {
877+
if (option.second == "True" || option.second == "true") {
878+
params.migraphx_fp16_enable = true;
879+
} else if (option.second == "False" || option.second == "false") {
880+
params.migraphx_fp16_enable = false;
881+
} else {
882+
ORT_THROW(
883+
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
884+
" 'True' or 'False'. Default value is 'False'.\n");
885+
}
886+
} else if(option.first == "migraphx_load_model_path") {
887+
if (!option.second.empty()) {
888+
load_model_path = option.second;
889+
params.migraphx_load_compiled_model_path = load_model_path.c_str();
890+
} else {
891+
ORT_THROW(
892+
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
893+
"file name i.e. 'model.mxr'.\n");
894+
}
895+
}
896+
else {
856897
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
857898
}
858899
}

0 commit comments

Comments
 (0)