Skip to content

Commit d8d4bd6

Browse files
author
Ted Themistokleous
committed
Add env flags for save/load of compiled migx models
Initial changes for the getting the flags into the state to save/load precompiled MIGraphX models in the EP.
1 parent 69cfcba commit d8d4bd6

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,28 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
153153
}
154154
}
155155

156+
//Save/load migraphx compiled models
157+
const std::string save_comp_model_env = onnxruntime::GetenvironmentVar(migraphx_env_vars::kSaveCompiledModel);
158+
if (!save_comp_model_env.empty()) {
159+
save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true);
160+
}
161+
162+
const std::string save_model_path_env = onnxruntime::GetenvironmentVar(migraphx_env_vars::ksaveCompiledPath);
163+
164+
if (save_compiled_model_ && !save_model_path_env.empty()) {
165+
save_compiled_path_ = save_model_path_env;
166+
}
167+
168+
const std::string load_comp_model_env = onnxruntime::GetenvironmentVar(migraphx_env_vars::kLoadCompiledModel);
169+
if (!load_comp_model_env.empty()) {
170+
load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? false : true);
171+
}
172+
173+
const std::string load_model_path_env = onnxruntime::GetenvironmentVar(migraphx_env_vars::kLoadCompiledPath);
174+
if (load_compiled_model_ && !load_model_path_env.empty()) {
175+
load_compiled_path_ = load_model_path_env;
176+
}
177+
156178
// dump unsupported ops
157179
const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps);
158180
if (!dump_model_ops_env.empty()) {
@@ -171,10 +193,15 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
171193
<< "device_id: " << device_id_
172194
<< ", migraphx_fp16_enable: " << fp16_enable_
173195
<< ", migraphx_int8_enable: " << int8_enable_
196+
<< ", migraphx_int8_enable: " << int8_enable_
174197
<< ", dump_model_ops: " << dump_model_ops_
175198
<< ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_
176199
<< ", int8_calibration_cache_available: " << int8_calibration_cache_available_
177-
<< ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_;
200+
<< ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_
201+
<< ", migraphx_save_compiled_model: " << save_compiled_model_
202+
<< ", migraphx_save_compiled_model_path: " << save_compiled_model_path_
203+
<< ", migraphx_load_compiled_model: " << load_compiled_model_
204+
<< ", migraphx_load_compiled_model_path: " << load_compiled_model_path_;
178205
}
179206

180207
MIGraphXExecutionProvider::~MIGraphXExecutionProvider() {

onnxruntime/core/providers/migraphx/migraphx_execution_provider.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
2626
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
2727
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
2828
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
29+
static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
30+
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
31+
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
32+
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
33+
2934
}; // namespace migraphx_env_vars
3035

3136
// Information to construct kernel function state.
@@ -44,6 +49,10 @@ struct MIGraphXFuncState {
4449
bool int8_enable = false;
4550
bool int8_calibration_cache_available = false;
4651
std::unordered_map<std::string, float> dynamic_range_map;
52+
bool save_compiled_mode = false;
53+
std::string save_compiled_path = false;
54+
bool load_compiled_mode = false;
55+
std::string load_compiled_path = false;
4756
bool dump_model_ops = false;
4857
};
4958

@@ -84,6 +93,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
8493
bool int8_use_native_migraphx_calibration_table_ = false;
8594
std::string calibration_cache_path_;
8695
std::unordered_map<std::string, float> dynamic_range_map;
96+
bool save_compiled_model_ = false;
97+
std::string save_compiled_path_;
98+
bool load_compiled_model_ = false;
99+
std::string load_compiled_path_;
87100
bool dump_model_ops_ = false;
88101
int device_id_;
89102
migraphx::target t_;

0 commit comments

Comments
 (0)