Skip to content

Commit fbee484

Browse files
author
Ted Themistokleous
committed
[MIGraphX EP] Add migraphx ep save load compiles (microsoft#20643)
Adds the ability for MIGraphX EP to save off or load compiled models to save time between inferences. Via Command line User should be able to set the save ability with ORT_MIGRAPHX_SAVE_COMPILED_MODEL ORT_MIGRAPHX_SAVE_COMPILE_PATH User should be able to set the load ability with ORT_MIGRAPHX_LOAD_COMPILED_MODEL ORT_MIGRAPHX_LOAD_COMPILE_PATH via Onnxruntime API migx_save_compiled_model migx_save_model_name migx_load_compiled_model migx_load_model_name The motivation for this is to leverage MIGraphX's existing API to save/load models after our compile step of graph optimization. For larger models or models which were compiled with additional tuning steps, this saves time after first compile and inference run, and thus speeds up the user experience in order to encourage development. --------- Co-authored-by: Ted Themistokleous <[email protected]>
1 parent 69e234a commit fbee484

File tree

8 files changed

+245
-86
lines changed

8 files changed

+245
-86
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,10 @@ typedef struct OrtMIGraphXProviderOptions {
608608
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
609609
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
610610
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
611+
int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true
612+
const char* migraphx_save_model_path; // migraphx model path name
613+
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
614+
const char* migraphx_load_model_path; // migraphx model path name
611615
} OrtMIGraphXProviderOptions;
612616

613617
/** \brief OpenVINO Provider Options

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 148 additions & 83 deletions
Large diffs are not rendered by default.

onnxruntime/core/providers/migraphx/migraphx_execution_provider.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
2121
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
2222
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
2323
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
24+
static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
25+
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
26+
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
27+
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
28+
2429
}; // namespace migraphx_env_vars
2530

2631
// Information to construct kernel function state.
@@ -39,6 +44,10 @@ struct MIGraphXFuncState {
3944
bool int8_enable = false;
4045
bool int8_calibration_cache_available = false;
4146
std::unordered_map<std::string, float> dynamic_range_map;
47+
bool save_compiled_mode = false;
48+
std::string save_compiled_path;
49+
bool load_compiled_mode = false;
50+
std::string load_compiled_path;
4251
bool dump_model_ops = false;
4352
};
4453

@@ -82,7 +91,11 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
8291
bool int8_calibration_cache_available_ = false;
8392
bool int8_use_native_migraphx_calibration_table_ = false;
8493
std::string calibration_cache_path_;
85-
std::unordered_map<std::string, float> dynamic_range_map_;
94+
std::unordered_map<std::string, float> dynamic_range_map;
95+
bool save_compiled_model_ = false;
96+
std::string save_compiled_path_;
97+
bool load_compiled_model_ = false;
98+
std::string load_compiled_path_;
8699
bool dump_model_ops_ = false;
87100
migraphx::target t_;
88101
OrtMutex mgx_mu_;

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ constexpr const char* kFp16Enable = "trt_fp16_enable";
1717
constexpr const char* kInt8Enable = "migx_int8_enable";
1818
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
1919
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
20+
constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
21+
constexpr const char* kSaveModelPath = "migx_save_model_name";
22+
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
23+
constexpr const char* kLoadModelPath = "migx_load_model_name";
2024

2125
} // namespace provider_option_names
2226
} // namespace migraphx
@@ -39,6 +43,8 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
3943
})
4044
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
4145
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
46+
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
47+
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
4248
.Parse(options));
4349

4450
return info;
@@ -49,6 +55,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
4955
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
5056
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
5157
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
58+
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
59+
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
5260
};
5361
return options;
5462
}
@@ -58,6 +66,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
5866
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
5967
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
6068
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
69+
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
70+
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
6171
};
6272
return options;
6373
}

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ struct MIGraphXExecutionProviderInfo {
1919
bool int8_enable{false};
2020
std::string int8_calibration_table_name{""};
2121
bool int8_use_native_calibration_table{false};
22+
bool save_compiled_model{true};
23+
std::string save_model_file{"./compiled_model.mxr"};
24+
bool load_compiled_model{true};
25+
std::string load_model_file{"./compiled_model.mxr"};
2226

2327
static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
2428
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);

onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ struct MIGraphX_Provider : Provider {
6666
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name;
6767
}
6868
info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
69+
info.save_compiled_model = options.migraphx_save_compiled_model;
70+
info.save_model_file = "";
71+
if (options.migraphx_save_model_path != nullptr) {
72+
info.save_model_file = options.migraphx_save_model_path;
73+
}
74+
info.load_compiled_model = options.migraphx_load_compiled_model;
75+
info.load_model_file = "";
76+
if (options.migraphx_load_model_path != nullptr) {
77+
info.load_model_file = options.migraphx_load_model_path;
78+
}
6979
return std::make_shared<MIGraphXProviderFactory>(info);
7080
}
7181

@@ -92,6 +102,11 @@ struct MIGraphX_Provider : Provider {
92102
}
93103

94104
migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table;
105+
106+
migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model;
107+
migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str();
108+
migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model;
109+
migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str();
95110
}
96111

97112
ProviderOptions GetProviderOptions(const void* provider_options) override {

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,14 +763,20 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
763763
} else if (type == kMIGraphXExecutionProvider) {
764764
#ifdef USE_MIGRAPHX
765765
std::string calibration_table;
766+
std::string save_model_path;
767+
std::string load_model_path;
766768
auto it = provider_options_map.find(type);
767769
if (it != provider_options_map.end()) {
768770
OrtMIGraphXProviderOptions params{
769771
0,
770772
0,
771773
0,
772774
0,
773-
nullptr};
775+
nullptr,
776+
1,
777+
"./compiled_model.mxr",
778+
1,
779+
"./compiled_model.mxr"};
774780
for (auto option : it->second) {
775781
if (option.first == "device_id") {
776782
if (!option.second.empty()) {
@@ -817,6 +823,44 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
817823
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
818824
" 'True' or 'False'. Default value is 'False'.\n");
819825
}
826+
} else if (option.first == "migraphx_save_compiled_model") {
827+
if (option.second == "True" || option.second == "true") {
828+
params.migraphx_fp16_enable = true;
829+
} else if (option.second == "False" || option.second == "false") {
830+
params.migraphx_fp16_enable = false;
831+
} else {
832+
ORT_THROW(
833+
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
834+
" 'True' or 'False'. Default value is 'False'.\n");
835+
}
836+
} else if (option.first == "migraphx_save_model_path") {
837+
if (!option.second.empty()) {
838+
save_model_path = option.second;
839+
params.migraphx_save_model_path = save_model_path.c_str();
840+
} else {
841+
ORT_THROW(
842+
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
843+
"file name i.e. 'compiled_model.mxr'.\n");
844+
}
845+
} else if (option.first == "migraphx_load_compiled_model") {
846+
if (option.second == "True" || option.second == "true") {
847+
params.migraphx_fp16_enable = true;
848+
} else if (option.second == "False" || option.second == "false") {
849+
params.migraphx_fp16_enable = false;
850+
} else {
851+
ORT_THROW(
852+
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
853+
" 'True' or 'False'. Default value is 'False'.\n");
854+
}
855+
} else if (option.first == "migraphx_load_model_path") {
856+
if (!option.second.empty()) {
857+
load_model_path = option.second;
858+
params.migraphx_load_model_path = load_model_path.c_str();
859+
} else {
860+
ORT_THROW(
861+
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
862+
"file name i.e. 'compiled_model.mxr'.\n");
863+
}
820864
} else {
821865
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
822866
}

onnxruntime/test/util/default_providers.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
7676
0,
7777
0,
7878
0,
79-
nullptr};
79+
nullptr,
80+
1,
81+
"./compiled_model.mxr",
82+
1,
83+
"./compiled_model.mxr"};
8084
return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider();
8185
#else
8286
return nullptr;

0 commit comments

Comments
 (0)