Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,10 @@ typedef struct OrtMIGraphXProviderOptions {
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true
const char* migraphx_save_model_path; // migraphx model path name
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
const char* migraphx_load_model_path; // migraphx model path name
} OrtMIGraphXProviderOptions;

/** \brief OpenVINO Provider Options
Expand Down
205 changes: 159 additions & 46 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";

}; // namespace migraphx_env_vars

// Information to construct kernel function state.
Expand All @@ -44,6 +49,10 @@ struct MIGraphXFuncState {
bool int8_enable = false;
bool int8_calibration_cache_available = false;
std::unordered_map<std::string, float> dynamic_range_map;
bool save_compiled_mode = false;
std::string save_compiled_path;
bool load_compiled_mode = false;
std::string load_compiled_path;
bool dump_model_ops = false;
};

Expand Down Expand Up @@ -84,6 +93,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
bool int8_use_native_migraphx_calibration_table_ = false;
std::string calibration_cache_path_;
std::unordered_map<std::string, float> dynamic_range_map;
bool save_compiled_model_ = false;
std::string save_compiled_path_;
bool load_compiled_model_ = false;
std::string load_compiled_path_;
bool dump_model_ops_ = false;
int device_id_;
migraphx::target t_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ constexpr const char* kFp16Enable = "trt_fp16_enable";
constexpr const char* kInt8Enable = "migx_int8_enable";
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
constexpr const char* kSaveModelPath = "migx_save_model_name";
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
constexpr const char* kLoadModelPath = "migx_load_model_name";

} // namespace provider_option_names
} // namespace migraphx
Expand All @@ -39,6 +43,8 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
.Parse(options));

return info;
Expand All @@ -49,6 +55,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
};
return options;
}
Expand All @@ -58,6 +66,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
};
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ struct MIGraphXExecutionProviderInfo {
bool int8_enable{false};
std::string int8_calibration_table_name{""};
bool int8_use_native_calibration_table{false};
bool save_compiled_model{true};
std::string save_model_file{"./compiled_model.mxr"};
bool load_compiled_model{true};
std::string load_model_file{"./compiled_model.mxr"};

static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ struct MIGraphX_Provider : Provider {
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name;
}
info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
info.save_compiled_model = options.migraphx_save_compiled_model;
info.save_model_file = "";
if (options.migraphx_save_model_path != nullptr) {
info.save_model_file = options.migraphx_save_model_path;
}
info.load_compiled_model = options.migraphx_load_compiled_model;
info.load_model_file = "";
if (options.migraphx_load_model_path != nullptr) {
info.load_model_file = options.migraphx_load_model_path;
}
return std::make_shared<MIGraphXProviderFactory>(info);
}

Expand All @@ -79,6 +89,11 @@ struct MIGraphX_Provider : Provider {
}

migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table;

migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model;
migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str();
migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model;
migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str();
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
Expand Down
46 changes: 45 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -823,14 +823,20 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else if (type == kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
std::string calibration_table;
std::string save_model_path;
std::string load_model_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtMIGraphXProviderOptions params{
0,
0,
0,
0,
nullptr};
nullptr,
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr"};
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
Expand Down Expand Up @@ -877,6 +883,44 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_compiled_model") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_model_path") {
if (!option.second.empty()) {
save_model_path = option.second;
params.migraphx_save_model_path = save_model_path.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_load_compiled_model") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_load_model_path") {
if (!option.second.empty()) {
load_model_path = option.second;
params.migraphx_load_model_path = load_model_path.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else {
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
}
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
0,
0,
0,
nullptr};
nullptr,
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr"};
return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider();
#else
return nullptr;
Expand Down