Skip to content

Commit 11e7a1b

Browse files
TedThemistokleousTed Themistokleous
andauthored
[MIGraphX EP] Add migraphx ep save load compiles (#20643)
### Description Adds the ability for MIGraphX EP to save off or load compiled models to save time between inferences. Via Command line User should be able to set the save ability with ORT_MIGRAPHX_SAVE_COMPILED_MODEL ORT_MIGRAPHX_SAVE_COMPILE_PATH User should be able to set the load ability with ORT_MIGRAPHX_LOAD_COMPILED_MODEL ORT_MIGRAPHX_LOAD_COMPILE_PATH via Onnxruntime API migx_save_compiled_model migx_save_model_name migx_load_compiled_model migx_load_model_name ### Motivation and Context The motivation for this is to leverage MIGraphX's existing API to save/load models after our compile step of graph optimization. For larger models or models which were compiled with additional tuning steps, this saves time after first compile and inference run, and thus speeds up the user experience in order to encourage development. --------- Co-authored-by: Ted Themistokleous <[email protected]>
1 parent d4470fe commit 11e7a1b

File tree

8 files changed

+255
-48
lines changed

8 files changed

+255
-48
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,10 @@ typedef struct OrtMIGraphXProviderOptions {
617617
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
618618
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
619619
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
620+
int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true
621+
const char* migraphx_save_model_path; // migraphx model path name
622+
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
623+
const char* migraphx_load_model_path; // migraphx model path name
620624
} OrtMIGraphXProviderOptions;
621625

622626
/** \brief OpenVINO Provider Options

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 159 additions & 46 deletions
Large diffs are not rendered by default.

onnxruntime/core/providers/migraphx/migraphx_execution_provider.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
2626
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
2727
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
2828
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
29+
static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
30+
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
31+
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
32+
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
33+
2934
}; // namespace migraphx_env_vars
3035

3136
// Information to construct kernel function state.
@@ -44,6 +49,10 @@ struct MIGraphXFuncState {
4449
bool int8_enable = false;
4550
bool int8_calibration_cache_available = false;
4651
std::unordered_map<std::string, float> dynamic_range_map;
52+
bool save_compiled_mode = false;
53+
std::string save_compiled_path;
54+
bool load_compiled_mode = false;
55+
std::string load_compiled_path;
4756
bool dump_model_ops = false;
4857
};
4958

@@ -84,6 +93,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
8493
bool int8_use_native_migraphx_calibration_table_ = false;
8594
std::string calibration_cache_path_;
8695
std::unordered_map<std::string, float> dynamic_range_map;
96+
bool save_compiled_model_ = false;
97+
std::string save_compiled_path_;
98+
bool load_compiled_model_ = false;
99+
std::string load_compiled_path_;
87100
bool dump_model_ops_ = false;
88101
int device_id_;
89102
migraphx::target t_;

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ constexpr const char* kFp16Enable = "trt_fp16_enable";
1717
constexpr const char* kInt8Enable = "migx_int8_enable";
1818
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
1919
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
20+
constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
21+
constexpr const char* kSaveModelPath = "migx_save_model_name";
22+
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
23+
constexpr const char* kLoadModelPath = "migx_load_model_name";
2024

2125
} // namespace provider_option_names
2226
} // namespace migraphx
@@ -39,6 +43,8 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
3943
})
4044
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
4145
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
46+
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
47+
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
4248
.Parse(options));
4349

4450
return info;
@@ -49,6 +55,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
4955
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
5056
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
5157
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
58+
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
59+
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
5260
};
5361
return options;
5462
}
@@ -58,6 +66,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
5866
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
5967
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
6068
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
69+
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
70+
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
6171
};
6272
return options;
6373
}

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ struct MIGraphXExecutionProviderInfo {
1919
bool int8_enable{false};
2020
std::string int8_calibration_table_name{""};
2121
bool int8_use_native_calibration_table{false};
22+
bool save_compiled_model{true};
23+
std::string save_model_file{"./compiled_model.mxr"};
24+
bool load_compiled_model{true};
25+
std::string load_model_file{"./compiled_model.mxr"};
2226

2327
static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
2428
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);

onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ struct MIGraphX_Provider : Provider {
5353
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name;
5454
}
5555
info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
56+
info.save_compiled_model = options.migraphx_save_compiled_model;
57+
info.save_model_file = "";
58+
if (options.migraphx_save_model_path != nullptr) {
59+
info.save_model_file = options.migraphx_save_model_path;
60+
}
61+
info.load_compiled_model = options.migraphx_load_compiled_model;
62+
info.load_model_file = "";
63+
if (options.migraphx_load_model_path != nullptr) {
64+
info.load_model_file = options.migraphx_load_model_path;
65+
}
5666
return std::make_shared<MIGraphXProviderFactory>(info);
5767
}
5868

@@ -79,6 +89,11 @@ struct MIGraphX_Provider : Provider {
7989
}
8090

8191
migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table;
92+
93+
migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model;
94+
migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str();
95+
migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model;
96+
migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str();
8297
}
8398

8499
ProviderOptions GetProviderOptions(const void* provider_options) override {

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,14 +823,20 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
823823
} else if (type == kMIGraphXExecutionProvider) {
824824
#ifdef USE_MIGRAPHX
825825
std::string calibration_table;
826+
std::string save_model_path;
827+
std::string load_model_path;
826828
auto it = provider_options_map.find(type);
827829
if (it != provider_options_map.end()) {
828830
OrtMIGraphXProviderOptions params{
829831
0,
830832
0,
831833
0,
832834
0,
833-
nullptr};
835+
nullptr,
836+
1,
837+
"./compiled_model.mxr",
838+
1,
839+
"./compiled_model.mxr"};
834840
for (auto option : it->second) {
835841
if (option.first == "device_id") {
836842
if (!option.second.empty()) {
@@ -877,6 +883,44 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
877883
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
878884
" 'True' or 'False'. Default value is 'False'.\n");
879885
}
886+
} else if (option.first == "migraphx_save_compiled_model") {
887+
if (option.second == "True" || option.second == "true") {
888+
params.migraphx_fp16_enable = true;
889+
} else if (option.second == "False" || option.second == "false") {
890+
params.migraphx_fp16_enable = false;
891+
} else {
892+
ORT_THROW(
893+
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
894+
" 'True' or 'False'. Default value is 'False'.\n");
895+
}
896+
} else if (option.first == "migraphx_save_model_path") {
897+
if (!option.second.empty()) {
898+
save_model_path = option.second;
899+
params.migraphx_save_model_path = save_model_path.c_str();
900+
} else {
901+
ORT_THROW(
902+
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
903+
"file name i.e. 'compiled_model.mxr'.\n");
904+
}
905+
} else if (option.first == "migraphx_load_compiled_model") {
906+
if (option.second == "True" || option.second == "true") {
907+
params.migraphx_fp16_enable = true;
908+
} else if (option.second == "False" || option.second == "false") {
909+
params.migraphx_fp16_enable = false;
910+
} else {
911+
ORT_THROW(
912+
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
913+
" 'True' or 'False'. Default value is 'False'.\n");
914+
}
915+
} else if (option.first == "migraphx_load_model_path") {
916+
if (!option.second.empty()) {
917+
load_model_path = option.second;
918+
params.migraphx_load_model_path = load_model_path.c_str();
919+
} else {
920+
ORT_THROW(
921+
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
922+
"file name i.e. 'compiled_model.mxr'.\n");
923+
}
880924
} else {
881925
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
882926
}

onnxruntime/test/util/default_providers.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
7676
0,
7777
0,
7878
0,
79-
nullptr};
79+
nullptr,
80+
1,
81+
"./compiled_model.mxr",
82+
1,
83+
"./compiled_model.mxr"};
8084
return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider();
8185
#else
8286
return nullptr;

0 commit comments

Comments
 (0)