diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 16701f2e0d923..5c61963a2f39c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -617,6 +617,10 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name + int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, nonzero = true + const char* migraphx_save_model_path; // migraphx save model path name + int migraphx_load_compiled_model; // migraphx load compiled model. Default 0 = false, nonzero = true + const char* migraphx_load_model_path; // migraphx load model path name } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 9acbb9c17ec36..581376623ffe0 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -153,6 +153,28 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv } } + // Save/load migraphx compiled models + const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); + if (!save_comp_model_env.empty()) { + save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? 
false : true); + } + + const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); + + if (save_compiled_model_ && !save_model_path_env.empty()) { + save_compiled_path_ = save_model_path_env; + } + + const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); + if (!load_comp_model_env.empty()) { + load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? false : true); + } + + const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); + if (load_compiled_model_ && !load_model_path_env.empty()) { + load_compiled_path_ = load_model_path_env; + } + // dump unsupported ops const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); if (!dump_model_ops_env.empty()) { @@ -171,10 +193,15 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv << "device_id: " << device_id_ << ", migraphx_fp16_enable: " << fp16_enable_ << ", migraphx_int8_enable: " << int8_enable_ + << ", migraphx_int8_enable: " << int8_enable_ << ", dump_model_ops: " << dump_model_ops_ << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ - << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_; + << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ + << ", migraphx_save_compiled_model: " << save_compiled_model_ + << ", migraphx_save_compiled_model_path: " << save_compiled_path_ + << ", migraphx_load_compiled_model: " << load_compiled_model_ + << ", migraphx_load_compiled_model_path: " << load_compiled_path_; } MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { @@ -265,7 +292,7 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, break; default: LOGS_DEFAULT(WARNING) << "MiGraphx: 
unsupported data type " << type << ", fallback to CPU"; - LOGS_DEFAULT(WARNING) << "implementation" << std::endl; + LOGS_DEFAULT(WARNING) << "implementation"; return false; } @@ -1008,11 +1035,11 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v result.push_back(ComputeCapability::Create(std::move(sub_graph))); } else { // unsupported_nodes_idx.empty() if (dump_model_ops_) { - LOGS_DEFAULT(INFO) << "============= Unsupported nodes ====================" << std::endl; + LOGS_DEFAULT(INFO) << "============= Unsupported nodes ===================="; for (auto idx : unsupported_nodes) { LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl; } - LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************" << std::endl; + LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************"; } if (unsupported_nodes.size() > 10) { @@ -1087,6 +1114,34 @@ bool get_input_output_names(const GraphViewer& graph, return no_input_shape; } +// Attempt to load a model and catch any exceptions on load fail. +// Useful to default to EP to trigger the compile if file doesn't exist or loading fails. +bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { + try { + if (load_enable) { + LOGS_DEFAULT(INFO) << "Attempting to load model at:" << path; + prog = migraphx::load(path.c_str()); + LOGS_DEFAULT(INFO) << "load model : Success"; + return true; + } else { + return false; + } + } catch (...) 
{ + return false; + } + return false; +} + +void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { + if (save_enable) { + LOGS_DEFAULT(INFO) << "Model Save at " << out_path << ": Begin" << std::endl; + migraphx::file_options fo; + fo.set_file_format("msgpack"); + migraphx::save(prog, out_path.c_str(), fo); + LOGS_DEFAULT(INFO) << "Model Save: Complete" << std::endl; + } +} + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1117,39 +1172,56 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } std::vector input_names, output_names; - no_input_shape = get_input_output_names(graph_body_viewer, input_names, output_names); + no_input_shape = get_input_output_names(graph_body_viewer, input_names, output_names) or no_input_shape; // call first: `or` short-circuits and would otherwise skip filling input_names/output_names // by parsing the model_proto, create a program corresponding to // the input fused_node migraphx::program prog; if (!no_input_shape) { - prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); - if (fp16_enable_) { - migraphx::quantize_fp16(prog); - } + if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { + LOGS_DEFAULT(INFO) << "Input shapes detected and no precompiled model loaded: parsing and quantizing model"; + prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable_ && int8_calibration_cache_available_) { - migraphx::quantize_int8_options quant_opts; - migraphx::program_parameters quant_params; + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if (int8_enable_ && int8_calibration_cache_available_) { + LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl; + migraphx::quantize_int8_options quant_opts; + migraphx::program_parameters quant_params; - auto param_shapes = prog.get_parameter_shapes(); + auto 
param_shapes = prog.get_parameter_shapes(); - for (auto&& name : param_shapes.names()) { - auto dynamic_range_i = dynamic_range_map.find(name); - if (dynamic_range_i != dynamic_range_map.end()) { - quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second))); + // Add all calibration data read in from int8 table + for (auto& [cal_key, cal_val] : dynamic_range_map) { + auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); + quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); } + quant_opts.add_calibration_data(quant_params); + + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); + + // perform static quantization on the programs + migraphx::quantize_int8(prog, t_, quant_opts); + LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl; + } + + if (fp16_enable_) { + LOGS_DEFAULT(INFO) << "Quantizing input program to fp16" << std::endl; + migraphx::quantize_fp16(prog); + LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete" << std::endl; } - quant_opts.add_calibration_data(quant_params); - // perform static quantization on the programs - migraphx::quantize_int8(prog, t_, quant_opts); + migraphx::compile_options co; + co.set_fast_math(false); + LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; + prog.compile(t_, co); + LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl; + + save_compiled_model(prog, save_compiled_model_, save_compiled_path_); } - migraphx::compile_options co; - co.set_fast_math(false); - prog.compile(t_, co); + auto prog_output_shapes = prog.get_output_shapes(); for (std::size_t i = 0; i < output_names.size(); ++i) { auto out_len = prog_output_shapes[i].lengths(); @@ -1169,7 +1241,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& *p = {context->allocate_func, context->release_func, context->allocator_handle, 
map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, - int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_}; + int8_calibration_cache_available_, dynamic_range_map, + save_compiled_model_, save_compiled_path_, + load_compiled_model_, load_compiled_path_, dump_model_ops_}; *state = p.release(); return 0; }; @@ -1199,6 +1273,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; if (no_input_shape) { + LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again" << std::endl; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1210,6 +1285,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { + LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model" << std::endl; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1243,33 +1319,67 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // input shapes are different, needs to re-parse onnx and // re-compile the program if (!input_shape_match) { - prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); - if (fp16_enable) { - migraphx::quantize_fp16(prog); - } + if (!load_precompiled_model(prog, mgx_state->load_compiled_mode, mgx_state->load_compiled_path)) { + LOGS_DEFAULT(VERBOSE) << "Input shapes mismatch detected. 
Recompiling" << std::endl; + prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); + + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if (int8_enable && int8_calibration_cache_available) { + LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl; + migraphx::quantize_int8_options quant_opts; + migraphx::program_parameters quant_params; + + auto param_shapes = prog.get_parameter_shapes(); + + // Add input parameter data and the values they're set to + for (auto&& name : param_shapes.names()) { + if (map_input_name_index.count(name) > 0) { + auto input_tensor = ctx.GetInput(map_input_name_index[name]); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shape = tensor_info.GetShape(); + const auto tensor_type = tensor_info.GetElementType(); + + migraphx_shape_datatype_t mgx_type; + getMIGraphXType(tensor_type, mgx_type); + auto mgx_s = param_shapes[name]; + + if (mgx_type != mgx_s.type()) { + LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; + } + quant_params.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); + } + } - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable && int8_calibration_cache_available) { - migraphx::quantize_int8_options quant_opts; - migraphx::program_parameters quant_params; + // Add all calibration data read in from int8 table + for (auto& [cal_key, cal_val] : map_dynamic_range) { + auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); + quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); + } + quant_opts.add_calibration_data(quant_params); - auto param_shapes = prog.get_parameter_shapes(); + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); - for (auto&& name : param_shapes.names()) { - auto dynamic_range_i = 
map_dynamic_range.find(name); - if (dynamic_range_i != map_dynamic_range.end()) { - quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second))); - } + // perform static quantization on the programs + migraphx::quantize_int8(prog, t, quant_opts); + LOGS_DEFAULT(INFO) << "Quantize Int8: Completed" << std::endl; } - quant_opts.add_calibration_data(quant_params); - // perform static quantization on the programs - migraphx::quantize_int8(prog, t, quant_opts); + if (fp16_enable) { + LOGS_DEFAULT(INFO) << "Quantize fp16: Begin" << std::endl; + migraphx::quantize_fp16(prog); + LOGS_DEFAULT(INFO) << "Quantize fp16: Completed" << std::endl; + } + + LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; + migraphx::compile_options co; + co.set_fast_math(false); + prog.compile(t, co); + + save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); } - migraphx::compile_options co; - co.set_fast_math(false); - prog.compile(t, co); mgx_state->prog = prog; param_shapes = prog.get_parameter_shapes(); no_input_shape = false; @@ -1281,6 +1391,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { + LOGS_DEFAULT(INFO) << "Setting parameters for:" << name << std::endl; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1293,6 +1404,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (mgx_type != mgx_s.type()) { LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } + + LOGS_DEFAULT(INFO) << "Writing Raw tensor data " << std::endl; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } @@ -1353,7 +1466,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& HIP_CALL_THROW(hipMemcpy(output_data, 
gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice)); } } - } + }; return Status::OK(); }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index c3617f409e72c..1977f71b8b1cf 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -26,6 +26,11 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; +static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; +static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH"; +static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; +static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH"; + }; // namespace migraphx_env_vars // Information to construct kernel function state. 
@@ -44,6 +49,10 @@ struct MIGraphXFuncState { bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; + bool save_compiled_mode = false; + std::string save_compiled_path; + bool load_compiled_mode = false; + std::string load_compiled_path; bool dump_model_ops = false; }; @@ -84,6 +93,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool int8_use_native_migraphx_calibration_table_ = false; std::string calibration_cache_path_; std::unordered_map dynamic_range_map; + bool save_compiled_model_ = false; + std::string save_compiled_path_; + bool load_compiled_model_ = false; + std::string load_compiled_path_; bool dump_model_ops_ = false; int device_id_; migraphx::target t_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index b7d7a77853df6..2a135b7324f3a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -17,6 +17,10 @@ constexpr const char* kFp16Enable = "trt_fp16_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; +constexpr const char* kSaveCompiledModel = "migx_save_compiled_model"; +constexpr const char* kSaveModelPath = "migx_save_model_name"; +constexpr const char* kLoadCompiledModel = "migx_load_compiled_model"; +constexpr const char* kLoadModelPath = "migx_load_model_name"; } // namespace provider_option_names } // namespace migraphx @@ -39,6 +43,8 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) 
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) + .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) .Parse(options)); return info; @@ -49,6 +55,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, + {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, + {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, }; return options; } @@ -58,6 +66,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, + {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, + {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, }; return options; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 18ac30fdc1283..8411e3eef096b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -19,6 +19,10 @@ struct 
MIGraphXExecutionProviderInfo { bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; + bool save_compiled_model{true}; + std::string save_model_file{"./compiled_model.mxr"}; + bool load_compiled_model{true}; + std::string load_model_file{"./compiled_model.mxr"}; static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index f985682ddc735..dd24dbdc76d2f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -53,6 +53,16 @@ struct MIGraphX_Provider : Provider { info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name; } info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0; + info.save_compiled_model = options.migraphx_save_compiled_model; + info.save_model_file = ""; + if (options.migraphx_save_model_path != nullptr) { + info.save_model_file = options.migraphx_save_model_path; + } + info.load_compiled_model = options.migraphx_load_compiled_model; + info.load_model_file = ""; + if (options.migraphx_load_model_path != nullptr) { + info.load_model_file = options.migraphx_load_model_path; + } return std::make_shared(info); } @@ -79,6 +89,11 @@ struct MIGraphX_Provider : Provider { } migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; + + migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model; + migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str(); + migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model; + migx_options.migraphx_load_model_path = 
internal_options.load_model_file.c_str(); } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 7f9a6e13d7864..c59aa35577a97 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -823,6 +823,8 @@ std::unique_ptr CreateExecutionProviderInstance( } else if (type == kMIGraphXExecutionProvider) { #ifdef USE_MIGRAPHX std::string calibration_table; + std::string save_model_path; + std::string load_model_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtMIGraphXProviderOptions params{ @@ -830,7 +832,11 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 0, 0, - nullptr}; + nullptr, + 1, + "./compiled_model.mxr", + 1, + "./compiled_model.mxr"}; for (auto option : it->second) { if (option.first == "device_id") { if (!option.second.empty()) { @@ -877,6 +883,44 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "migraphx_save_compiled_model") { + if (option.second == "True" || option.second == "true") { + params.migraphx_save_compiled_model = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_save_compiled_model = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } + } else if (option.first == "migraphx_save_model_path") { + if (!option.second.empty()) { + save_model_path = option.second; + params.migraphx_save_model_path = save_model_path.c_str(); + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_path' should be a " + "file name i.e. 
'compiled_model.mxr'.\n"); + } + } else if (option.first == "migraphx_load_compiled_model") { + if (option.second == "True" || option.second == "true") { + params.migraphx_load_compiled_model = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_load_compiled_model = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } + } else if (option.first == "migraphx_load_model_path") { + if (!option.second.empty()) { + load_model_path = option.second; + params.migraphx_load_model_path = load_model_path.c_str(); + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_path' should be a " + "file name i.e. 'compiled_model.mxr'.\n"); + } } else { ORT_THROW("Invalid MIGraphX EP option: ", option.first); } diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index e353cc73b2986..6f07385729555 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -76,7 +76,11 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 0, 0, 0, - nullptr}; + nullptr, + 1, + "./compiled_model.mxr", + 1, + "./compiled_model.mxr"}; return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider(); #else return nullptr;