From a20104c56a6174164ed9701205a603aa1cbc2cd9 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Mon, 28 Jul 2025 22:16:09 +0200 Subject: [PATCH 01/46] Add in BF16 datatype and quantization support to MIGraphX EP --- .../core/session/onnxruntime_c_api.h | 2 +- .../migraphx/migraphx_execution_provider.cc | 48 +++++++++++++++++-- .../migraphx/migraphx_execution_provider.h | 3 ++ .../migraphx_execution_provider_info.cc | 4 ++ .../migraphx_execution_provider_info.h | 5 +- .../migraphx/migraphx_provider_factory.cc | 2 + .../python/onnxruntime_pybind_state.cc | 10 ++++ 7 files changed, 69 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d87e9e083185b..a5ed5917bd53d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -773,7 +773,7 @@ typedef struct OrtMIGraphXProviderOptions { * \note If a ::OrtArenaCfg has been applied, it will override this field */ int migraphx_arena_extend_strategy; - + int migraphx_bf16_enable; // MIGraphX BF16 precision. Default 0 = false, nonzero = true } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 41b55e3baf508..e9648548781fa 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -127,6 +127,16 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut // Quantization fp16_enable_ = info.fp16_enable; +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) + bf16_enable_ = info.bf16_enable; +#endif + + if (bf16_enable_ and fp16_enable_) { + bf16_enable_ = false; + fp16_enable_ = false; + LOGS_DEFAULT(FATAL) << "MIGraphX: BF16 and FP16 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) fp8_enable_ = info.fp8_enable; #else @@ -136,6 +146,8 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut int8_enable_ = info.int8_enable; if (int8_enable_ and fp8_enable_) { + int8_enable_ = false; + fp8_enable_ = false; LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } @@ -178,6 +190,21 @@ void MIGraphXExecutionProvider::get_flags_from_env() { LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; } + const std::string bf16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kBF16Enable); + if (!bf16_enable_env.empty()) { +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) + bf16_enable_ = (std::stoi(bf16_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_BF16_ENABLE: " << fp16_enable_; +#else + LOGS_DEFAULT(WARNING) << "MIGraphX: BF16 Quantization requires ROCm 6.4.2 or greater"; + bf16_enable_ = false; +#endif + } + + if (bf16_enable_ and fp16_enable_) { + LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP16 and BF16 Quantization Mutually exclusive. Ignoring both flags"; + } + // whether fp8 quantization is enabled const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); if (!fp8_enable_env.empty()) { @@ -281,6 +308,7 @@ void MIGraphXExecutionProvider::get_flags_from_env() { void MIGraphXExecutionProvider::print_migraphx_ep_flags() { LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id << "\n migraphx_fp16_enable: " << fp16_enable_ + << "\n migraphx_bf16_enable: " << bf16_enable_ << "\n migraphx_fp8_enable: " << fp8_enable_ << "\n migraphx_int8_enable: " << int8_enable_ << "\n dump_model_ops: " << dump_model_ops_ @@ -354,6 +382,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: @@ -384,6 +413,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: mgx_type = migraphx_shape_half_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + mgx_type = migraphx_shape_bf16_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: mgx_type = migraphx_shape_float_type; break; @@ -1275,6 +1307,7 @@ void calibrate_and_quantize(migraphx::program& prog, const migraphx::target& t, const migraphx::program_parameters quant_params, bool fp16_enable, + bool bf16_enable, bool int8_enable, bool fp8_enable, bool int8_calibration_cache_available, @@ -1317,6 +1350,14 @@ void calibrate_and_quantize(migraphx::program& prog, migraphx::quantize_fp16(prog); LOGS_DEFAULT(WARNING) << "Quantizing fp16: Complete"; } + +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) + if (bf16_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to bf16"; + migraphx::quantize_bf16(prog); + LOGS_DEFAULT(WARNING) << "Quantizing bf16: Complete"; + } +#endif } void compile_program(migraphx::program& prog, @@ -1372,7 +1413,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); migraphx::program_parameters quant_params; - calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, bf16_enable_, int8_enable_, fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); compile_program(prog, t_, exhaustive_tune_); save_compiled_model(prog, save_compiled_model_, save_compiled_path_); @@ -1396,7 +1437,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; @@ -1421,6 +1462,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool bf16_enable = mgx_state->bf16_enable; bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; @@ -1507,7 +1549,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } } } - calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, + calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, fp8_enable, int8_calibration_cache_available, map_dynamic_range); compile_program(prog, t, exhaustive_tune_); save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index aecccdd54d697..45bf5e6a5c2b8 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -17,6 +17,7 @@ namespace onnxruntime { namespace migraphx_env_vars { static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; +static const char kBF16Enable[] = "ORT_MIGRAPHX_BF16_ENABLE"; static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; @@ -44,6 +45,7 @@ struct MIGraphXFuncState { std::mutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool bf16_enable = false; bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; @@ -100,6 +102,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; + bool bf16_enable_ = false; bool fp8_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index cf21d791cfe6b..bccd3439fa2c4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -21,6 +21,7 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kBf16Enable = "migx_bf16_enable"; constexpr const char* kFp8Enable = "migx_fp8_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; @@ -83,6 +84,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kBf16Enable, info.bf16_enable) .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) @@ -102,6 +104,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kBf16Enable, MakeStringWithClassicLocale(info.bf16_enable)}, {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, @@ -121,6 +124,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kBf16Enable, MakeStringWithClassicLocale(info.migraphx_bf16_enable)}, {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index a598052c5f025..7af185ee3c269 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -42,6 +42,7 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool bf16_enable{false}; bool fp8_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; @@ -77,7 +78,9 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { (static_cast(info.int8_use_native_calibration_table) << 20) ^ (static_cast(info.save_compiled_model) << 21) ^ (static_cast(info.load_compiled_model) << 22) ^ - (static_cast(info.exhaustive_tune) << 23); + (static_cast(info.exhaustive_tune) << 23) ^ + (static_cast(info.bf16_enable) << 24); + onnxruntime::HashCombine(data, value); onnxruntime::HashCombine(info.mem_limit, value); diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index f10ba87e88002..62a3bd1b1f7e3 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -82,6 +82,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.bf16_enable = options.migraphx_bf16_enable; info.fp8_enable = options.migraphx_fp8_enable; info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; @@ -110,6 +111,7 @@ struct MIGraphX_Provider : Provider { auto& migx_options = *reinterpret_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_bf16_enable = internal_options.bf16_enable; migx_options.migraphx_fp8_enable = internal_options.fp8_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index acf0681cf8752..c3fcfa0a26936 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -989,6 +989,16 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "migraphx_bf16_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_bf16_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_bf16_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_bf16_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } } else if (option.first == "migraphx_fp8_enable") { if (option.second == "True" || option.second == "true") { params.migraphx_fp8_enable = true; From 5506dfabd0b6ea754cbee02508a91ae333e608bc Mon Sep 17 00:00:00 2001 From: urpetkov-amd <127323899+urpetkov-amd@users.noreply.github.com> Date: Mon, 28 Jul 2025 02:43:34 -0600 Subject: [PATCH 02/46] Fixing corrupted flags issue DML and MGX --- .../python/onnxruntime_pybind_ortvalue.cc | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index d1d4d6f3cdad5..a8131a0df5dbc 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -383,21 +383,27 @@ void addOrtValueMethods(pybind11::module& m) { // Converts Tensor into a numpy array .def("numpy", [](const OrtValue* ml_value) -> py::object { ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects"); - + const auto& device = ml_value->Get().Location().device; + switch (device.Vendor()) { #ifdef USE_CUDA - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction()); -#elif USE_ROCM - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction()); -#elif USE_CANN - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); -#elif USE_DML - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction()); -#elif USE_MIGRAPHX - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction()); -#else - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr); + case OrtDevice::VendorIds::NVIDIA: + return GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction()); +#endif +#ifdef USE_MIGRAPHX + case OrtDevice::VendorIds::AMD: + return GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction()); +#endif +#ifdef USE_DML + case OrtDevice::VendorIds::MICROSOFT: + return GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction()); +#endif +#ifdef USE_CANN + case OrtDevice::VendorIds::HUAWEI: + return GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); #endif - return obj; }) + default: + return GetPyObjFromTensor(*ml_value, nullptr, nullptr); + } }) #if defined(ENABLE_DLPACK) .def("to_dlpack", [](OrtValue* ort_value) -> py::object { return py::reinterpret_steal(ToDlpack(*ort_value)); }, "Returns a DLPack representing the tensor. This method does not copy the pointer shape, " From fa11154ff15364ad616d8da9d93667d0b320c621 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Mon, 28 Jul 2025 17:15:23 +0200 Subject: [PATCH 03/46] Check migraphx_onnx_options_set_external_data_path exists --- cmake/CMakeLists.txt | 1 + cmake/onnxruntime_providers_migraphx.cmake | 9 +++++++++ .../providers/migraphx/migraphx_execution_provider.cc | 9 +++++++-- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index a76be16572a03..e00356c5c3edb 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -29,6 +29,7 @@ include(CheckLanguage) include(CMakeDependentOption) include(FetchContent) include(CheckFunctionExists) +include(CheckSymbolExists) include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables # TODO: update this once all system adapt c++20 diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 495ff093326ad..c3bbfe5a09cf9 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -62,6 +62,15 @@ target_link_libraries(onnxruntime_providers_migraphx PRIVATE stdc++fs) endif() + set(CMAKE_REQUIRED_LIBRARIES migraphx::c) + + check_symbol_exists(migraphx_onnx_options_set_external_data_path + "migraphx/migraphx.h" HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH) + + if(HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH=1) + endif() + if (onnxruntime_ENABLE_TRAINING_OPS) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_training) target_link_libraries(onnxruntime_providers_migraphx PRIVATE onnxruntime_training) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index e9648548781fa..8cbb6904ce110 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1410,6 +1410,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!no_input_shape) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; +#ifndef ENABLE_TRAINING_CORE +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + options.set_external_data_path(model_path_.parent_path().string()); +#endif +#endif prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); migraphx::program_parameters quant_params; @@ -1521,8 +1526,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling" << std::endl; #ifndef ENABLE_TRAINING_CORE -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 2) - cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string()); +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + cmp_options.set_external_data_path(model_path_.parent_path().string()); #endif #endif prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); From 131512d8b07db9c6d4d24534dfb4cd811cf601d6 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 24 Jul 2025 01:32:07 +0200 Subject: [PATCH 04/46] Use #ifdef instead of #if --- onnxruntime/python/onnxruntime_pybind_mlvalue.cc | 1 + onnxruntime/python/onnxruntime_pybind_state_common.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 958c9fc46bcd8..431fb0f422b81 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -207,6 +207,7 @@ std::unique_ptr GetGPUDataTransfer() { #endif #ifdef USE_MIGRAPHX + void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_MIGraphX().MIGraphXMemcpy_HostToDevice(dst, src, num_bytes); } diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 706c151936192..a73b701a36ddb 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -40,7 +40,7 @@ struct OrtStatus { #define BACKEND_PROC "CPU" #endif -#if USE_DNNL +#ifdef USE_DNNL #define BACKEND_DNNL "-DNNL" #else #define BACKEND_DNNL "" From 8e158650e69f9e988a0eaf8883f7fb47b8b03c26 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 24 Jul 2025 01:12:16 +0200 Subject: [PATCH 05/46] Simlify MIGraphX EP CMake --- cmake/onnxruntime_providers_migraphx.cmake | 24 ++++++---------------- cmake/onnxruntime_unittests.cmake | 4 ---- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index c3bbfe5a09cf9..9984dc322f5ec 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -2,21 +2,11 @@ # Licensed under the MIT License. add_definitions(-DUSE_MIGRAPHX=1) - set(BUILD_LIBRARY_ONLY 1) - add_definitions("-DONNX_ML=1") - add_definitions("-DONNX_NAMESPACE=onnx") - include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR}) - set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME}) - include_directories(${onnx_SOURCE_DIR}) + include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR} ${onnx_SOURCE_DIR}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) + if (CMAKE_COMPILER_IS_GNUCC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") endif() - set(CXX_VERSION_DEFINED TRUE) - set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") - endif() # Add search paths for default rocm installation list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH}) @@ -33,8 +23,6 @@ find_package(hip REQUIRED) find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME}) - set(migraphx_libs migraphx::c hip::host) - file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" @@ -44,12 +32,12 @@ source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) + add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE migraphx::c hip::host ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/migraphx/onnxruntime) set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") - target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1 ONNX_ML=1 ONNX_NAMESPACE=onnx) if(MSVC) set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def) target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index c3bebba3bab54..b6062d0154db6 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -610,7 +610,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_COREML) @@ -691,9 +690,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*) - list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_NNAPI_BUILTIN) From e70b8dc740def29f35b5998c79600c645a77cd93 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 24 Jul 2025 01:00:07 +0200 Subject: [PATCH 06/46] Use SHARED lib not MODULE for MIGraphX EP --- cmake/onnxruntime_providers_migraphx.cmake | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 9984dc322f5ec..626ac211d0a6c 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -30,7 +30,7 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) + onnxruntime_add_shared_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_link_libraries(onnxruntime_providers_migraphx PRIVATE migraphx::c hip::host ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) @@ -67,16 +67,9 @@ endif() endif() - if(CMAKE_SYSTEM_NAME STREQUAL "Windows") - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - else() - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - endif() + install(TARGETS onnxruntime_providers_migraphx + EXPORT onnxruntime_providers_migraphxTargets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) From d74ecbaf502f236b556a3fe20793ad0c0403f72a Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 24 Jul 2025 00:49:26 +0200 Subject: [PATCH 07/46] Use integer type equivalent to architecture to represent memory addresses --- .../providers/migraphx/migraphx_execution_provider_info.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index bccd3439fa2c4..f53b9510f504d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -62,7 +62,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions .AddValueParser( migraphx::provider_option_names::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { - size_t address; + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); alloc = reinterpret_cast(address); return Status::OK(); @@ -70,7 +70,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions .AddValueParser( migraphx::provider_option_names::kGpuExternalFree, [&free](const std::string& value_str) -> Status { - size_t address; + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); free = reinterpret_cast(address); return Status::OK(); @@ -78,7 +78,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions .AddValueParser( migraphx::provider_option_names::kGpuExternalEmptyCache, [&empty_cache](const std::string& value_str) -> Status { - size_t address; + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); empty_cache = reinterpret_cast(address); return Status::OK(); From 5bb4cee736ea8054e8fd95770e24cee67e08253a Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 24 Jul 2025 00:53:43 +0200 Subject: [PATCH 08/46] Declaration of 'param_shapes' hides previous local declaration --- .../providers/migraphx/migraphx_execution_provider.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 8cbb6904ce110..112df7384e673 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1534,9 +1534,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::program_parameters quant_params; if ((int8_enable xor fp8_enable) and int8_calibration_cache_available) { - auto param_shapes = prog.get_parameter_shapes(); + auto local_param_shapes = prog.get_parameter_shapes(); // Add input parameter data and the values they're set to - for (auto&& name : param_shapes.names()) { + for (auto&& name : local_param_shapes.names()) { if (map_input_name_index.count(name) > 0) { auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); @@ -1545,12 +1545,12 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx_shape_datatype_t mgx_type; getMIGraphXType(tensor_type, mgx_type); - auto mgx_s = param_shapes[name]; + auto mgx_s = local_param_shapes[name]; if (mgx_type != mgx_s.type()) { LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - quant_params.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); + quant_params.add(name, migraphx::argument(local_param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } } } From 982041381dcf8a37ef3f24a3b08907cafca948a4 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 24 Jul 2025 00:55:46 +0200 Subject: [PATCH 09/46] Add inline to header file function definitions --- .../migraphx_execution_provider_utils.h | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 9274b5696185c..e70d58b16c8d9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -19,7 +19,7 @@ namespace fs = std::filesystem; namespace onnxruntime { -bool IsGraphInput(const GraphViewer& graph, const std::string& name) { +inline bool IsGraphInput(const GraphViewer& graph, const std::string& name) { const auto& graph_inputs = graph.GetInputs(); std::vector input_names(graph_inputs.size()); std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) { @@ -28,12 +28,12 @@ bool IsGraphInput(const GraphViewer& graph, const std::string& name) { return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); } -bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { +inline bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; return graph.GetInitializedTensor(name, initializer); } -const Node* GetInputNode(const Node& node, int arg_index) { +inline const Node* GetInputNode(const Node& node, int arg_index) { int index = 0; for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index) { if (index == arg_index) { @@ -44,7 +44,7 @@ const Node* GetInputNode(const Node& node, int arg_index) { return nullptr; } -std::size_t getNodeInputNum(const Node& node) { +inline std::size_t getNodeInputNum(const Node& node) { std::size_t node_num = 0; for (auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) { node_num++; @@ -53,14 +53,14 @@ std::size_t getNodeInputNum(const Node& node) { return node_num; } -bool isInputNode(const Node* node, const std::string& name) { +inline bool isInputNode(const Node* node, const std::string& name) { auto outputs = node->OutputDefs(); return std::any_of(outputs.begin(), outputs.end(), [&](auto out) { return (out->Name() == name); }); } -bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { +inline bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { if (node == nullptr) { return false; } @@ -113,10 +113,10 @@ bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector return true; } -bool canEvalNodeArgument(const GraphViewer& graph, - const Node* node, - std::vector indices, - std::vector& input_nodes) { +inline bool canEvalNodeArgument(const GraphViewer& graph, + const Node* node, + std::vector indices, + std::vector& input_nodes) { input_nodes.clear(); std::vector in_nodes; for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) { @@ -152,7 +152,7 @@ bool canEvalNodeArgument(const GraphViewer& graph, return true; } -float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { +inline float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { int s = (input >> 31) & 0x01; int e = ((input & 0x7f800000) >> 23) - 127; int p = -1; @@ -184,10 +184,10 @@ float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { * Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models * */ -bool ReadDynamicRange(const std::string file_name, - const bool is_calibration_table, - std::unordered_map& dynamic_range_map) { +inline bool ReadDynamicRange(const std::string file_name, + const bool is_calibration_table, + std::unordered_map& dynamic_range_map) { std::ifstream infile(file_name, std::ios::binary | std::ios::in); if (!infile) { return false; @@ -240,7 +240,7 @@ bool ReadDynamicRange(const std::string file_name, * Get cache by name * */ -std::string GetCachePath(const std::string& root, const std::string& name) { +inline std::string GetCachePath(const std::string& root, const std::string& name) { if (root.empty()) { return name; } else { From e7ce7ff23ff13601d17d68b51da5bf8a0211253b Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Mon, 28 Jul 2025 19:36:26 +0200 Subject: [PATCH 10/46] use string_view --- .../core/framework/provider_options_utils.h | 47 +++++++++++++++++++ .../providers/shared_library/provider_api.h | 6 +++ 2 files changed, 53 insertions(+) diff --git a/include/onnxruntime/core/framework/provider_options_utils.h b/include/onnxruntime/core/framework/provider_options_utils.h index 5967fb91523d0..e2c25dde24054 100644 --- a/include/onnxruntime/core/framework/provider_options_utils.h +++ b/include/onnxruntime/core/framework/provider_options_utils.h @@ -89,6 +89,21 @@ class ProviderOptionsParser { return *this; } + template + ProviderOptionsParser& AddValueParser( + std::string_view name, ValueParserType value_parser) { + ORT_ENFORCE( + value_parsers_.emplace(name, ValueParser{value_parser}).second, + "Provider option \"", name, "\" already has a value parser."); + return *this; + } + + template + ProviderOptionsParser& AddValueParser( + const char* name, ValueParserType value_parser) { + return AddValueParser(std::string_view{name}, value_parser); + } + /** * Adds a parser for a particular provider option value which converts a * value to the right type and assigns it to the given reference. @@ -111,6 +126,22 @@ class ProviderOptionsParser { }); } + template + ProviderOptionsParser& AddAssignmentToReference( + std::string_view name, ValueType& dest) { + return AddValueParser( + name, + [&dest](std::string_view value_str) -> Status { + return ParseStringWithClassicLocale(value_str, dest); + }); + } + + template + ProviderOptionsParser& AddAssignmentToReference( + const char* name, ValueType& dest) { + return AddAssignmentToReference(std::string_view{name}, dest); + } + /** * Adds a parser for a particular provider option value which maps an * enumeration name to a value and assigns it to the given reference. @@ -135,6 +166,22 @@ class ProviderOptionsParser { }); } + template + ProviderOptionsParser& AddAssignmentToEnumReference( + std::string_view name, const EnumNameMapping& mapping, EnumType& dest) { + return AddValueParser( + name, + [&mapping, &dest](const std::string& value_str) -> Status { + return NameToEnum(mapping, value_str, dest); + }); + } + + template + ProviderOptionsParser& AddAssignmentToEnumReference( + const char* name, const EnumNameMapping& mapping, EnumType& dest) { + return AddAssignmentToEnumReference(std::string_view{name}, mapping, dest); + } + /** * Parses the given provider options. */ diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 71d51c4c2992d..1e4a94a63b749 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -326,6 +326,12 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe const logging::Logger& logger); std::string GetEnvironmentVar(const std::string& var_name); +inline std::string GetEnvironmentVar(std::string_view var_name) { + return GetEnvironmentVar(std::string{var_name}); +} +inline std::string GetEnvironmentVar(const char* var_name) { + return GetEnvironmentVar(std::string_view{var_name}); +} namespace profiling { From f8450f1f4ff9f9506d34ac6a0373eb09131883b9 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 29 Jul 2025 19:05:36 +0200 Subject: [PATCH 11/46] Use std::string_view for literals --- .../migraphx/migraphx_execution_provider.cc | 2 +- .../migraphx/migraphx_execution_provider.h | 32 ++++--- .../migraphx_execution_provider_info.cc | 96 +++++++------------ .../migraphx_execution_provider_info.h | 23 +++++ .../python/onnxruntime_pybind_state.cc | 30 +++--- 5 files changed, 92 insertions(+), 91 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 112df7384e673..bd997454d62ca 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -291,7 +291,7 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } // dump unsupported ops - const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); + const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kDumpModelOps); if (!dump_model_ops_env.empty()) { dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 45bf5e6a5c2b8..eea186416330d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,6 +3,7 @@ #pragma once +#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" #include @@ -13,24 +14,25 @@ #include #include +using namespace std::literals::string_view_literals; + namespace onnxruntime { namespace migraphx_env_vars { -static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; -static const char kBF16Enable[] = "ORT_MIGRAPHX_BF16_ENABLE"; -static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; -static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; -static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; -static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; -static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; -static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; -static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; -static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"; -static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; -static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"; -static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; - -}; // namespace migraphx_env_vars +constexpr auto kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE"sv; +constexpr auto kBF16Enable = "ORT_MIGRAPHX_BF16_ENABLE"sv; +constexpr auto kFP8Enable = "ORT_MIGRAPHX_FP8_ENABLE"sv; +constexpr auto kINT8Enable = "ORT_MIGRAPHX_INT8_ENABLE"sv; +constexpr auto kDumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS"sv; +constexpr auto kINT8CalibrationTableName = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"sv; +constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH"sv; +constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"sv; +constexpr auto kSaveCompiledModel = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"sv; +constexpr auto kSavedModelPath = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"sv; +constexpr auto kLoadCompiledModel = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"sv; +constexpr auto kLoadModelPath = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"sv; +constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; +} // namespace migraphx_env_vars // Information to construct kernel function state. struct MIGraphXFuncState { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index f53b9510f504d..c9df70413e881 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -17,29 +17,6 @@ const EnumNameMapping arena_extend_strategy_mapping{ {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, }; -namespace migraphx { -namespace provider_option_names { -constexpr const char* kDeviceId = "device_id"; -constexpr const char* kFp16Enable = "trt_fp16_enable"; -constexpr const char* kBf16Enable = "migx_bf16_enable"; -constexpr const char* kFp8Enable = "migx_fp8_enable"; -constexpr const char* kInt8Enable = "migx_int8_enable"; -constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; -constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; -constexpr const char* kSaveCompiledModel = "migx_save_compiled_model"; -constexpr const char* kSaveModelPath = "migx_save_model_name"; -constexpr const char* kLoadCompiledModel = "migx_load_compiled_model"; -constexpr const char* kLoadModelPath = "migx_load_model_name"; -constexpr const char* kExhaustiveTune = "migx_exhaustive_tune"; -constexpr const char* kMemLimit = "migx_mem_limit"; -constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy"; -constexpr const char* kGpuExternalAlloc = "migx_external_alloc"; -constexpr const char* kGpuExternalFree = "migx_external_free"; -constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache"; - -} // namespace provider_option_names -} // namespace migraphx - MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { MIGraphXExecutionProviderInfo info{}; void* alloc = nullptr; @@ -48,7 +25,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( - migraphx::provider_option_names::kDeviceId, + migraphx_provider_option::kDeviceId, [&info](const std::string& value_str) -> Status { ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); int num_devices{}; @@ -60,7 +37,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalAlloc, + migraphx_provider_option::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); @@ -68,7 +45,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalFree, + migraphx_provider_option::kGpuExternalFree, [&free](const std::string& value_str) -> Status { std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); @@ -76,22 +53,22 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalEmptyCache, + migraphx_provider_option::kGpuExternalEmptyCache, [&empty_cache](const std::string& value_str) -> Status { std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); empty_cache = reinterpret_cast(address); return Status::OK(); }) - .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kBf16Enable, info.bf16_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune) - .AddAssignmentToReference(migraphx::provider_option_names::kMemLimit, info.mem_limit) - .AddAssignmentToEnumReference(migraphx::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) + .AddAssignmentToReference(migraphx_provider_option::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx_provider_option::kBf16Enable, info.bf16_enable) + .AddAssignmentToReference(migraphx_provider_option::kFp8Enable, info.fp8_enable) + .AddAssignmentToReference(migraphx_provider_option::kInt8Enable, info.int8_enable) + .AddAssignmentToReference(migraphx_provider_option::kSaveCompiledModel, info.save_compiled_model) + .AddAssignmentToReference(migraphx_provider_option::kLoadCompiledModel, info.load_compiled_model) + .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, info.exhaustive_tune) + .AddAssignmentToReference(migraphx_provider_option::kMemLimit, info.mem_limit) + .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) .Parse(options)); MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; @@ -102,36 +79,35 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXExecutionProviderInfo& info) { const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, - {migraphx::provider_option_names::kBf16Enable, MakeStringWithClassicLocale(info.bf16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)}, - {migraphx::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, - {migraphx::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, - {migraphx::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, - {migraphx::provider_option_names::kArenaExtendStrategy, - EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)}, + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(info.device_id)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(info.fp16_enable)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(info.bf16_enable)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(info.fp8_enable)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(info.int8_enable)}, + {std::string{migraphx_provider_option::kSaveCompiledModel}, MakeStringWithClassicLocale(info.save_compiled_model)}, + {std::string{migraphx_provider_option::kLoadCompiledModel}, MakeStringWithClassicLocale(info.load_compiled_model)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(info.mem_limit)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(info.exhaustive_tune)}, }; return options; } ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) { const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, - {migraphx::provider_option_names::kBf16Enable, MakeStringWithClassicLocale(info.migraphx_bf16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)}, - {migraphx::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.migraphx_arena_extend_strategy))}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(info.device_id)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(info.migraphx_bf16_enable)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, + {std::string{migraphx_provider_option::kSaveCompiledModel}, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, + {std::string{migraphx_provider_option::kLoadCompiledModel}, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(info.migraphx_mem_limit)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, static_cast(info.migraphx_arena_extend_strategy))}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, }; return options; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 7af185ee3c269..4a2f4a6521e2c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -5,6 +5,7 @@ #include #include +#include #include "core/framework/ortdevice.h" #include "core/common/hash_combine.h" @@ -12,8 +13,30 @@ #include "core/framework/provider_options.h" #include "core/session/onnxruntime_c_api.h" +using namespace std::literals::string_view_literals; + namespace onnxruntime { +namespace migraphx_provider_option { +constexpr auto kDeviceId = "device_id"sv; +constexpr auto kFp16Enable = "migraphx_fp16_enable"sv; +constexpr auto kBf16Enable = "migraphx_bf16_enable"sv; +constexpr auto kFp8Enable = "migraphx_fp8_enable"sv; +constexpr auto kInt8Enable = "migraphx_int8_enable"sv; +constexpr auto kInt8CalibTable = "migraphx_int8_calibration_table_name"sv; +constexpr auto kInt8UseNativeCalibTable = "migraphx_int8_use_native_calibration_table"sv; +constexpr auto kSaveCompiledModel = "migraphx_save_compiled_model"sv; +constexpr auto kSaveModelPath = "migraphx_save_model_name"sv; +constexpr auto kLoadCompiledModel = "migraphx_load_compiled_model"sv; +constexpr auto kLoadModelPath = "migraphx_load_model_name"sv; +constexpr auto kExhaustiveTune = "migraphx_exhaustive_tune"sv; +constexpr auto kMemLimit = "migraphx_mem_limit"sv; +constexpr auto kArenaExtendStrategy = "migraphx_arena_extend_strategy"sv; +constexpr auto kGpuExternalAlloc = "migraphx_external_alloc"sv; +constexpr auto kGpuExternalFree = "migraphx_external_free"sv; +constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; +} // namespace migraphx_provider_option + // Information needed to construct MIGraphX execution providers. struct MIGraphXExecutionProviderExternalAllocatorInfo { void* alloc{nullptr}; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index c3fcfa0a26936..fc640269fa661 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -979,7 +979,7 @@ static std::shared_ptr CreateExecutionProviderFactory } else { ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n"); } - } else if (option.first == "migraphx_fp16_enable") { + } else if (option.first == migraphx_provider_option::kFp16Enable) { if (option.second == "True" || option.second == "true") { params.migraphx_fp16_enable = true; } else if (option.second == "False" || option.second == "false") { @@ -989,7 +989,7 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_bf16_enable") { + } else if (option.first == migraphx_provider_option::kBf16Enable) { if (option.second == "True" || option.second == "true") { params.migraphx_bf16_enable = true; } else if (option.second == "False" || option.second == "false") { @@ -999,7 +999,7 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_bf16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_fp8_enable") { + } else if (option.first == migraphx_provider_option::kFp8Enable) { if (option.second == "True" || option.second == "true") { params.migraphx_fp8_enable = true; } else if (option.second == "False" || option.second == "false") { @@ -1009,7 +1009,7 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_int8_enable") { + } else if (option.first == migraphx_provider_option::kInt8Enable) { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true; } else if (option.second == "False" || option.second == "false") { @@ -1019,7 +1019,7 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_int8_calibration_table_name") { + } else if (option.first == migraphx_provider_option::kInt8CalibTable) { if (!option.second.empty()) { calibration_table = option.second; params.migraphx_int8_calibration_table_name = calibration_table.c_str(); @@ -1028,7 +1028,7 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a " "file name i.e. 'cal_table'.\n"); } - } else if (option.first == "migraphx_use_native_calibration_table") { + } else if (option.first == migraphx_provider_option::kInt8UseNativeCalibTable) { if (option.second == "True" || option.second == "true") { params.migraphx_use_native_calibration_table = true; } else if (option.second == "False" || option.second == "false") { @@ -1038,17 +1038,17 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_save_compiled_model") { + } else if (option.first == migraphx_provider_option::kSaveCompiledModel) { if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; + params.migraphx_save_compiled_model = true; } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; + params.migraphx_save_compiled_model = false; } else { ORT_THROW( "[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_save_model_path") { + } else if (option.first == migraphx_provider_option::kSaveModelPath) { if (!option.second.empty()) { save_model_path = option.second; params.migraphx_save_model_path = save_model_path.c_str(); @@ -1057,17 +1057,17 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_name' should be a " "file name i.e. 'compiled_model.mxr'.\n"); } - } else if (option.first == "migraphx_load_compiled_model") { + } else if (option.first == migraphx_provider_option::kLoadCompiledModel) { if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; + params.migraphx_load_compiled_model = true; } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; + params.migraphx_load_compiled_model = false; } else { ORT_THROW( "[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_load_model_path") { + } else if (option.first == migraphx_provider_option::kLoadModelPath) { if (!option.second.empty()) { load_model_path = option.second; params.migraphx_load_model_path = load_model_path.c_str(); @@ -1076,7 +1076,7 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a " "file name i.e. 'compiled_model.mxr'.\n"); } - } else if (option.first == "migraphx_exhaustive_tune") { + } else if (option.first == migraphx_provider_option::kExhaustiveTune) { if (option.second == "True" || option.second == "true") { params.migraphx_exhaustive_tune = true; } else if (option.second == "False" || option.second == "false") { From 231462574bbf2fba0df88e7929a3d3d7b0ed3661 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 29 May 2025 10:36:21 +0200 Subject: [PATCH 12/46] Use MIGraphX ONNXRT DLL location to search for MIGX and HIP runtime DLLs --- .../migraphx/migraphx_provider_factory.cc | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 62a3bd1b1f7e3..751e45af5c6a3 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -2,6 +2,12 @@ // Licensed under the MIT License #include +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#include +#endif + #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" #include "migraphx_execution_provider.h" @@ -164,6 +170,23 @@ struct MIGraphX_Provider : Provider { } void Initialize() override { +#ifdef _WIN32 + HMODULE module = nullptr; + if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + static_cast(static_cast(InitializeRegistry)), + &module) != 0) { + std::vector pathBuf; + for (;;) { + pathBuf.resize(pathBuf.size() + MAX_PATH); + if (const auto writen = GetModuleFileNameW(module, pathBuf.data(), static_cast(pathBuf.size())); writen < pathBuf.size()) { + break; + } + } + std::filesystem::path path(pathBuf.begin(), pathBuf.end()); + SetDllDirectoryW(path.parent_path().native().c_str()); + } +#endif InitializeRegistry(); } From 303129c63c64636c943e7021a892e94acdca1895 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Mon, 28 Jul 2025 16:05:57 +0200 Subject: [PATCH 13/46] Reduce number of log messages --- .../migraphx/migraphx_execution_provider.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index bd997454d62ca..25317eb81a21c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1278,9 +1278,9 @@ bool get_input_output_names(const GraphViewer& graph, bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { try { if (load_enable) { - LOGS_DEFAULT(WARNING) << "Attempting to load model at:" << path; + LOGS_DEFAULT(VERBOSE) << "Attempting to load model at:" << path; prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(WARNING) << "load model : Success"; + LOGS_DEFAULT(VERBOSE) << "load model : Success"; return true; } else { return false; @@ -1293,11 +1293,11 @@ bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::stri void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { if (save_enable) { - LOGS_DEFAULT(WARNING) << "Model Save at " << out_path << ": Begin"; + LOGS_DEFAULT(VERBOSE) << "Model Save at " << out_path << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(WARNING) << "Model Save: Complete"; + LOGS_DEFAULT(VERBOSE) << "Model Save: Complete"; } } @@ -1409,7 +1409,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!no_input_shape) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; + LOGS_DEFAULT(VERBOSE) << "No input shapes detected quantizing model"; #ifndef ENABLE_TRAINING_CORE #ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH options.set_external_data_path(model_path_.parent_path().string()); @@ -1477,7 +1477,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; if (no_input_shape) { - LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; + LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1489,7 +1489,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; + LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1571,7 +1571,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; + LOGS_DEFAULT(VERBOSE) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1585,7 +1585,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; + LOGS_DEFAULT(VERBOSE) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } From 7edd17b9cd9064065708380ea670c9dea7aab88c Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Mon, 28 Jul 2025 16:15:36 +0200 Subject: [PATCH 14/46] implement cpplint suggestions --- .../providers/migraphx/gpu_data_transfer.h | 2 +- .../providers/migraphx/migraphx_allocator.cc | 11 +-- .../providers/migraphx/migraphx_allocator.h | 2 +- .../core/providers/migraphx/migraphx_call.cc | 10 +-- .../core/providers/migraphx/migraphx_call.h | 2 +- .../migraphx/migraphx_execution_provider.cc | 90 ++++++++++--------- .../migraphx/migraphx_execution_provider.h | 17 ++-- .../migraphx_execution_provider_info.cc | 6 +- .../migraphx_execution_provider_utils.h | 9 +- .../migraphx/migraphx_provider_factory.cc | 21 +++-- .../migraphx/migraphx_provider_factory.h | 11 ++- .../migraphx/migraphx_stream_handle.cc | 68 +++++++------- .../migraphx/migraphx_stream_handle.h | 11 ++- 13 files changed, 149 insertions(+), 111 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h index 5918716b3e77f..a4eb8efd2afea 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h @@ -3,7 +3,7 @@ #pragma once -#include "migraphx_inc.h" +#include "core/providers/migraphx/migraphx_inc.h" #include "core/framework/data_transfer.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index 1cac133ab0c2c..911a1a7fd18b9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -2,12 +2,11 @@ // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" -#include "migraphx_call.h" -#include "migraphx_allocator.h" +#include "core/providers/migraphx/migraphx_call.h" +#include "core/providers/migraphx/migraphx_allocator.h" #include "core/common/status.h" #include "core/framework/float16.h" -#include "core/common/status.h" -#include "gpu_data_transfer.h" +#include "core/providers/migraphx/gpu_data_transfer.h" namespace onnxruntime { @@ -55,7 +54,9 @@ void MIGraphXExternalAllocator::Free(void* p) { auto it = reserved_.find(p); if (it != reserved_.end()) { reserved_.erase(it); - if (empty_cache_) empty_cache_(); + if (empty_cache_ != nullptr) { + empty_cache_(); + } } } diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index f6b7788e0604c..10e06ab2f35ad 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -3,9 +3,9 @@ #pragma once +#include #include #include "core/framework/allocator.h" -#include namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc index 9807cd646e51c..61e41ab4c6284 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -1,13 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #ifdef _WIN32 #include #else #include #endif -#include #include "core/common/common.h" #include "core/common/status.h" #include "core/providers/shared_library/provider_api.h" @@ -15,8 +17,6 @@ namespace onnxruntime { -using namespace common; - template const char* RocmErrString(ERRTYPE x) { ORT_NOT_IMPLEMENTED(); @@ -48,8 +48,8 @@ std::conditional_t RocmCall( (void)hipGetDevice(¤tHipDevice); (void)hipGetLastError(); // clear last HIP error static char str[1024]; - snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", - libName, (int)retCode, RocmErrString(retCode), currentHipDevice, + snprintf(str, sizeof(str), "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", + libName, static_cast(retCode), RocmErrString(retCode), currentHipDevice, hostname.c_str(), file, line, exprString, msg); if constexpr (THRW) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h index 6d514e01aea96..64805784ba75f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.h +++ b/onnxruntime/core/providers/migraphx/migraphx_call.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "migraphx_inc.h" +#include "core/providers/migraphx/migraphx_inc.h" #include "core/common/common.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 25317eb81a21c..dd3eede481b9b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1,26 +1,34 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License -#include + +#include + #include +#include +#include +#include #include -#include +#include +#include #include -#include +#include +#include +#include +#include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/safeint.h" #include "core/common/logging/severity.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_execution_provider_utils.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" -#include -#include "migraphx_call.h" - -#include "migraphx_stream_handle.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_execution_provider_utils.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_call.h" +#include "core/providers/migraphx/migraphx_stream_handle.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245) @@ -131,7 +139,7 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut bf16_enable_ = info.bf16_enable; #endif - if (bf16_enable_ and fp16_enable_) { + if (bf16_enable_ && fp16_enable_) { bf16_enable_ = false; fp16_enable_ = false; LOGS_DEFAULT(FATAL) << "MIGraphX: BF16 and FP16 Quantization Mutually exclusive. Ignoring both Quantization flags"; @@ -145,18 +153,18 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut #endif int8_enable_ = info.int8_enable; - if (int8_enable_ and fp8_enable_) { + if (int8_enable_ && fp8_enable_) { int8_enable_ = false; fp8_enable_ = false; LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } - if (int8_enable_ xor fp8_enable_) { + if (int8_enable_ ^ fp8_enable_) { int8_calibration_cache_name_ = info.int8_calibration_table_name; int8_use_native_migraphx_calibration_table_ = info.int8_use_native_calibration_table; } - if (int8_enable_ or fp8_enable_) { + if (int8_enable_ || fp8_enable_) { int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); } @@ -201,7 +209,7 @@ void MIGraphXExecutionProvider::get_flags_from_env() { #endif } - if (bf16_enable_ and fp16_enable_) { + if (bf16_enable_ && fp16_enable_) { LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP16 and BF16 Quantization Mutually exclusive. Ignoring both flags"; } @@ -224,7 +232,7 @@ void MIGraphXExecutionProvider::get_flags_from_env() { LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; } - if (int8_enable_ and fp8_enable_) { + if (int8_enable_ && fp8_enable_) { LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } @@ -252,7 +260,7 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } } - if (int8_enable_ or fp8_enable_) { + if (int8_enable_ || fp8_enable_) { int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); } @@ -489,7 +497,7 @@ std::vector toVector(const ONNX_NAMESPACE::int64s& nums) { static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) { std::vector input_nodes; const auto& optype = node->OpType(); - if (optype == "ArgMax" or optype == "ArgMin") { + if (optype == "ArgMax" || optype == "ArgMin") { const auto& attributes = node->GetAttributes(); // we do not support select_last_index = 1 for now auto sli_attr = attributes.find("select_last_index"); @@ -507,7 +515,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } - if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and + if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) && (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) { return true; } @@ -535,7 +543,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co // storage order 1 (column major format) is not supported auto storage_order_attr = attributes.find("storage_order"); - if (storage_order_attr != attributes.end() and (*storage_order_attr).second.i() != 0) { + if (storage_order_attr != attributes.end() && (*storage_order_attr).second.i() != 0) { return true; } @@ -545,7 +553,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } auto data_type = input_type->tensor_type().elem_type(); - if (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 or + if (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) { return true; } @@ -556,7 +564,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } - if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and + if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) && (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) { return true; } @@ -612,7 +620,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } return true; } - } else if (optype == "Resize" or optype == "Upsample") { + } else if (optype == "Resize" || optype == "Upsample") { const auto& attributes = node->GetAttributes(); auto ct_attr = attributes.find("coordinate_transformation_mode"); if (ct_attr != attributes.end()) { @@ -650,7 +658,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } const auto& attributes = node->GetAttributes(); - if (attributes.count("starts") > 0 and attributes.count("ends") > 0) { + if (attributes.count("starts") > 0 && attributes.count("ends") > 0) { auto starts = toVector((*attributes.find("starts")).second.ints()); auto ends = toVector((*attributes.find("ends")).second.ints()); for (std::size_t i = 0; i < starts.size(); ++i) { @@ -688,7 +696,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } - } else if (optype == "Unsqueeze" or optype == "Squeeze") { + } else if (optype == "Unsqueeze" || optype == "Squeeze") { const auto& args = node->InputDefs(); if (args.size() == 2) { if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { @@ -717,9 +725,9 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v if (args.size() == 2) { std::vector node_inputs; if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) { - return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) { - return std::find(git.begin(), git.end(), index) != git.end(); - })); + return !std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto i) { + return std::find(git.begin(), git.end(), i) != git.end(); + }); } else { return true; } @@ -889,12 +897,14 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st erased.insert(output); } // Only when output is neither in input list nor erased list, add the output to output list - else if (erased.find(output) == erased.end()) { - if (std::find(graph_output_names.begin(), - graph_output_names.end(), output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; + else { + if (erased.find(output) == erased.end()) { + if (std::find(graph_output_names.begin(), + graph_output_names.end(), output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; } - fused_outputs[output] = output_order++; } } } @@ -1313,7 +1323,7 @@ void calibrate_and_quantize(migraphx::program& prog, bool int8_calibration_cache_available, std::unordered_map& dynamic_range_map) { // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { + if ((int8_enable ^ fp8_enable) && int8_calibration_cache_available) { LOGS_DEFAULT(WARNING) << "Quantizing input program"; auto param_shapes = prog.get_parameter_shapes(); @@ -1506,8 +1516,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto mgx_s = param_shapes[name]; auto mgx_lens = mgx_s.lengths(); auto mgx_strides = mgx_s.strides(); - if (mgx_lens.size() == 1 and mgx_lens[0] == 1 and - mgx_strides.size() == 1 and mgx_strides[0] == 0) { + if (mgx_lens.size() == 1 && mgx_lens[0] == 1 && + mgx_strides.size() == 1 && mgx_strides[0] == 0) { mgx_lens.clear(); } @@ -1533,7 +1543,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); migraphx::program_parameters quant_params; - if ((int8_enable xor fp8_enable) and int8_calibration_cache_available) { + if ((int8_enable ^ fp8_enable) && int8_calibration_cache_available) { auto local_param_shapes = prog.get_parameter_shapes(); // Add input parameter data and the values they're set to for (auto&& name : local_param_shapes.names()) { @@ -1646,7 +1656,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& static_cast(rocm_stream))); } } - }; + } return Status::OK(); }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index eea186416330d..9a54c3045606b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,17 +3,20 @@ #pragma once +#include +#include +#include +#include +#include +#include #include +#include +#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" -#include #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/migraphx/migraphx_call.h" -#include -#include -#include - using namespace std::literals::string_view_literals; namespace onnxruntime { @@ -64,7 +67,7 @@ struct MIGraphXFuncState { class MIGraphXExecutionProvider : public IExecutionProvider { public: explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); - ~MIGraphXExecutionProvider(); + ~MIGraphXExecutionProvider() override; void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); void get_flags_from_env(); @@ -85,7 +88,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; - virtual std::shared_ptr GetKernelRegistry() const override; + std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index c9df70413e881..fa9e613597a46 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -1,14 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/common/make_string.h" #include "core/common/parse_string.h" #include "core/framework/provider_options_utils.h" -#include "migraphx_inc.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/migraphx_inc.h" +#include "core/providers/migraphx/migraphx_call.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index e70d58b16c8d9..0282d3eecf529 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -3,12 +3,15 @@ #pragma once +#include +#include +#include #include -#include -#include #include -#include #include +#include +#include +#include #include "flatbuffers/idl.h" #include "core/providers/migraphx/ort_trt_int8_cal_table.fbs.h" #include "core/session/onnxruntime_cxx_api.h" diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 751e45af5c6a3..b4868b86b18ef 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License + #include +#include +#include +#include +#include #ifdef _WIN32 #define WIN32_LEAN_AND_MEAN @@ -10,25 +15,23 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_provider_factory_creator.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_provider_factory_creator.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" #include "core/framework/provider_options.h" #include "core/session/onnxruntime_c_api.h" -using namespace onnxruntime; - namespace onnxruntime { void InitializeRegistry(); void DeleteRegistry(); struct MIGraphXProviderFactory : IExecutionProviderFactory { - MIGraphXProviderFactory(const MIGraphXExecutionProviderInfo& info) : info_{info} {} - ~MIGraphXProviderFactory() override {} + explicit MIGraphXProviderFactory(MIGraphXExecutionProviderInfo info) : info_{std::move(info)} {} + ~MIGraphXProviderFactory() override = default; std::unique_ptr CreateProvider() override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index d1c9457bafa0f..0aa5ad172dfae 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -1,7 +1,12 @@ -// Copyright 2019 AMD AMDMIGraphX +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License -#include "core/framework/provider_options.h" -#include "onnxruntime_c_api.h" +#pragma once + +#include + +#include "core/framework/ortdevice.h" +#include "core/session/onnxruntime_c_api.h" namespace onnxruntime { class IAllocator; diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index 6e492327a73a3..cf4ccdc292c17 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -1,17 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include "migraphx_stream_handle.h" +#include +#include + +#include "core/providers/resource.h" +#include "core/providers/migraphx/migraphx_stream_handle.h" + +#define MIGRAPHX_RESOURCE_VERSION 1 namespace onnxruntime { -struct MIGraphXNotification : public synchronize::Notification { - MIGraphXNotification(Stream& s) : Notification(s) { +enum class MIGraphXResource { + hip_stream_t = rocm_resource_offset +}; + +struct MIGraphXNotification : synchronize::Notification { + explicit MIGraphXNotification(Stream& s) : Notification(s) { HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming)); } - ~MIGraphXNotification() { + ~MIGraphXNotification() override { if (event_) HIP_CALL_THROW(hipEventDestroy(event_)); } @@ -21,19 +30,19 @@ struct MIGraphXNotification : public synchronize::Notification { HIP_CALL_THROW(hipEventRecord(event_, static_cast(GetStream().GetHandle()))); } - void wait_on_device(Stream& device_stream) { - ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", - device_stream.GetDevice().ToString()); - // launch a wait command to the migraphx stream - HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), event_, 0)); - }; + void wait_on_device(Stream* device_stream) const { + if (device_stream != nullptr) { + ORT_ENFORCE(device_stream->GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream->GetDevice().ToString()); + // launch a wait command to the migraphx stream + HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream->GetHandle()), event_, 0)); + } + } - void wait_on_host() { - // CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); + void wait_on_host() const { HIP_CALL_THROW(hipEventSynchronize(event_)); } - hipEvent_t event_; + hipEvent_t event_{}; }; MIGraphXStream::MIGraphXStream(hipStream_t stream, @@ -41,15 +50,14 @@ MIGraphXStream::MIGraphXStream(hipStream_t stream, AllocatorPtr cpu_allocator, bool release_cpu_buffer_on_migraphx_stream) : Stream(stream, device), - cpu_allocator_(cpu_allocator), + cpu_allocator_(std::move(cpu_allocator)), release_cpu_buffer_on_migraphx_stream_(release_cpu_buffer_on_migraphx_stream) { } MIGraphXStream::~MIGraphXStream() { - ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); + ORT_IGNORE_RETURN_VALUE(MIGraphXStream::CleanUpOnRunEnd()); if (own_stream_) { - auto* handle = GetHandle(); - if (handle) + if (auto* handle = GetHandle()) HIP_CALL_THROW(hipStreamDestroy(static_cast(handle))); } } @@ -87,12 +95,12 @@ struct CpuBuffersInfo { std::unique_ptr buffers; // CPU buffer buffers[i]. // Number of buffer points in "buffers". - size_t n_buffers; + size_t n_buffers{}; }; static void ReleaseCpuBufferCallback(void* raw_info) { std::unique_ptr info = std::make_unique(); - info.reset(reinterpret_cast(raw_info)); + info.reset(static_cast(raw_info)); for (size_t i = 0; i < info->n_buffers; ++i) { info->allocator->Free(info->buffers[i]); } @@ -124,29 +132,28 @@ Status MIGraphXStream::CleanUpOnRunEnd() { } void* MIGraphXStream::GetResource(int version, int id) const { - ORT_ENFORCE(version <= ORT_ROCM_RESOURCE_VERSION, "resource version unsupported!"); - void* resource{}; + ORT_ENFORCE(version <= MIGRAPHX_RESOURCE_VERSION, "resource version unsupported!"); switch (id) { - case RocmResource::hip_stream_t: - return reinterpret_cast(GetHandle()); + case MIGraphXResource::hip_stream_t: + return GetHandle(); default: break; } - return resource; + return nullptr; } // CPU Stream command handles void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_device(*stream); + dynamic_cast(¬ification)->wait_on_device(stream); } void WaitMIGraphXNotificationOnHost(Stream* /*stream*/, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_host(); + dynamic_cast(¬ification)->wait_on_host(); } void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, const OrtDevice::DeviceType device_type, - AllocatorPtr cpu_allocator, + const AllocatorPtr& cpu_allocator, bool release_cpu_buffer_on_migraphx_stream, hipStream_t external_stream, bool use_existing_stream) { @@ -154,19 +161,20 @@ void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_r stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitMIGraphXNotificationOnDevice); // wait migraphx notification on cpu ep stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitMIGraphXNotificationOnHost); - if (!use_existing_stream) + if (!use_existing_stream) { stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream](const OrtDevice& device) { HIP_CALL_THROW(hipSetDevice(device.Id())); hipStream_t stream = nullptr; HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); }); - else + } else { stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream, external_stream](const OrtDevice& device) { return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); }); + } } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h index 886103690c661..b25eff1a1b9c6 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -2,12 +2,15 @@ // Licensed under the MIT License. #pragma once + +#include +#include + #include "core/framework/stream_handles.h" -#include "migraphx_inc.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/migraphx_inc.h" +#include "core/providers/migraphx/migraphx_call.h" namespace onnxruntime { -void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification); struct MIGraphXStream : Stream { MIGraphXStream(hipStream_t stream, @@ -37,7 +40,7 @@ struct MIGraphXStream : Stream { void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, const OrtDevice::DeviceType device_type, - AllocatorPtr cpu_allocator, + const AllocatorPtr& cpu_allocator, bool release_cpu_buffer_on_migraphx_stream, hipStream_t external_stream, bool use_existing_stream); From bc1ca7aba76cc25840461518118327285d2f37ed Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 24 Jul 2025 01:27:21 +0200 Subject: [PATCH 15/46] Use OrtDevice::DeviceId instead of int --- .../core/providers/migraphx/migraphx_provider_factory.cc | 6 +++--- .../core/providers/migraphx/migraphx_provider_factory.h | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index b4868b86b18ef..8c4bbb15d7c89 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -44,11 +44,11 @@ std::unique_ptr MIGraphXProviderFactory::CreateProvider() { } struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { - std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } - std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } @@ -70,7 +70,7 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } - std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { + std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); } } g_info; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index 0aa5ad172dfae..6baee291b7fe5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -17,11 +17,11 @@ enum class ArenaExtendStrategy : int32_t; struct MIGraphXExecutionProviderExternalAllocatorInfo; struct ProviderInfo_MIGraphX { - virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0; virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0; - virtual std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; protected: ~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance From c87168e876baf3f4fbaeba08d46341ee8206bcf7 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 24 Jul 2025 01:16:34 +0200 Subject: [PATCH 16/46] Use int instead of 'bool' for ONNXRT C interface --- include/onnxruntime/core/session/onnxruntime_c_api.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a5ed5917bd53d..6eebfa8372492 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -752,13 +752,13 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true - int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true + int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, nonzero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true const char* migraphx_save_model_path; // migraphx model path name int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true const char* migraphx_load_model_path; // migraphx model path name - bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false + int migraphx_exhaustive_tune; // MIGraphX tuned compile. Default = false, nonzero = true /** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t) * Defaults to SIZE_MAX. From 8c0d09de9872cab977a65938922784b51d2edc81 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Mon, 28 Jul 2025 19:44:16 +0200 Subject: [PATCH 17/46] allocate memory for an option value --- include/onnxruntime/core/common/common.h | 20 ++++++++++++++++--- .../onnxruntime/core/common/string_helper.h | 6 +++++- onnxruntime/core/common/helper.cc | 4 ++-- onnxruntime/core/common/path_string.h | 14 +++++++++++++ .../migraphx/migraphx_execution_provider.cc | 8 ++++---- .../migraphx/migraphx_execution_provider.h | 2 +- .../migraphx_execution_provider_utils.h | 18 ++++++----------- .../provider_bridge_provider.cc | 4 ++-- .../shared_library/provider_interfaces.h | 5 +++++ onnxruntime/core/session/onnxruntime_c_api.cc | 13 +++++++++--- .../core/session/provider_bridge_ort.cc | 2 ++ 11 files changed, 68 insertions(+), 28 deletions(-) diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index adfd341451aed..820d140ccaabc 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -294,12 +294,26 @@ inline std::string ToUTF8String(const std::string& s) { return s; } /** * Convert a wide character string to a UTF-8 string */ -std::string ToUTF8String(const std::wstring& s); - -std::wstring ToWideString(const std::string& s); +std::string ToUTF8String(std::wstring_view s); +inline std::string ToUTF8String(const wchar_t* s) { + return ToUTF8String(std::wstring_view{s}); +} +inline std::string ToUTF8String(const std::wstring& s) { + return ToUTF8String(std::wstring_view{s}); +} +std::wstring ToWideString(std::string_view s); +inline std::wstring ToWideString(const char* s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::string& s) { + return ToWideString(std::string_view{s}); +} inline std::wstring ToWideString(const std::wstring& s) { return s; } +inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; } #else inline std::string ToWideString(const std::string& s) { return s; } +inline std::string ToWideString(const char* s) { return s; } +inline std::string ToWideString(std::string_view s) { return std::string{s}; } #endif constexpr size_t kMaxStrLen = 4096; diff --git a/include/onnxruntime/core/common/string_helper.h b/include/onnxruntime/core/common/string_helper.h index 1304303132d5a..c0b331cb8e9a8 100644 --- a/include/onnxruntime/core/common/string_helper.h +++ b/include/onnxruntime/core/common/string_helper.h @@ -7,5 +7,9 @@ // forward declaration struct OrtAllocator; namespace onnxruntime { -char* StrDup(const std::string& str, OrtAllocator* allocator); +char* StrDup(std::string_view str, OrtAllocator* allocator); +inline char* StrDup(const std::string& str, OrtAllocator* allocator) { + return StrDup(std::string_view{str}, allocator); +} +wchar_t* StrDup(std::wstring_view str, OrtAllocator* allocator); } // namespace onnxruntime diff --git a/onnxruntime/core/common/helper.cc b/onnxruntime/core/common/helper.cc index 6a52db73df106..07cd1672b27c1 100644 --- a/onnxruntime/core/common/helper.cc +++ b/onnxruntime/core/common/helper.cc @@ -18,7 +18,7 @@ namespace onnxruntime { #ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { +std::string ToUTF8String(std::wstring_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); @@ -33,7 +33,7 @@ std::string ToUTF8String(const std::wstring& s) { return ret; } -std::wstring ToWideString(const std::string& s) { +std::wstring ToWideString(std::string_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); diff --git a/onnxruntime/core/common/path_string.h b/onnxruntime/core/common/path_string.h index 6cfb327cce08a..4ca326d76a37d 100644 --- a/onnxruntime/core/common/path_string.h +++ b/onnxruntime/core/common/path_string.h @@ -40,6 +40,12 @@ inline PathString ToPathString(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::wstring!"); +inline PathString ToPathString(std::string_view s) { + return ToWideString(s); +} +inline PathString ToPathString(const char* s) { + return ToWideString(s); +} inline PathString ToPathString(const std::string& s) { return ToWideString(s); } @@ -56,6 +62,14 @@ inline std::string PathToUTF8String(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::string!"); +inline PathString ToPathString(const char* s) { + return s; +} + +inline PathString ToPathString(std::string_view s) { + return PathString{s}; +} + inline PathChar ToLowerPathChar(PathChar c) { return std::tolower(c); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index dd3eede481b9b..765f6710f206f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -171,9 +171,9 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut // Load INT8 calibration table std::unordered_map dynamic_range_map; if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); + const auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path); + throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path.string()); } } @@ -267,9 +267,9 @@ void MIGraphXExecutionProvider::get_flags_from_env() { // Load INT8 calibration table std::unordered_map dynamic_range_map; if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); + const auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path); + throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path.string()); } } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 9a54c3045606b..6e230641aece0 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -113,7 +113,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; bool int8_use_native_migraphx_calibration_table_ = false; - std::string calibration_cache_path_; + std::filesystem::path calibration_cache_path_{}; std::unordered_map dynamic_range_map_; bool save_compiled_model_ = false; std::string save_compiled_path_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 0282d3eecf529..a70b848ed27e9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -187,12 +187,12 @@ inline float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { * Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models * */ -inline bool ReadDynamicRange(const std::string file_name, +inline bool ReadDynamicRange(const std::filesystem::path& filename, const bool is_calibration_table, std::unordered_map& dynamic_range_map) { - std::ifstream infile(file_name, std::ios::binary | std::ios::in); - if (!infile) { + std::ifstream infile{filename, std::ios::binary | std::ios::in}; + if (!infile.good()) { return false; } @@ -218,7 +218,7 @@ inline bool ReadDynamicRange(const std::string file_name, dynamic_range_map[tensor_name] = dynamic_range; } } else { - throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name); + throw std::runtime_error("This is not a TensorRT generated calibration table " + filename.string()); } } } else { @@ -243,14 +243,8 @@ inline bool ReadDynamicRange(const std::string file_name, * Get cache by name * */ -inline std::string GetCachePath(const std::string& root, const std::string& name) { - if (root.empty()) { - return name; - } else { - fs::path path = root; - path.append(name); - return path.string(); - } +inline std::filesystem::path GetCachePath(const std::filesystem::path& root, std::string_view name) { + return root.empty() ? std::filesystem::path{ToPathString(name)} : root / ToPathString(name); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 031a4df59d83f..765701689511b 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -790,11 +790,11 @@ Status LoadDynamicLibrary(onnxruntime::PathString library_name) { #endif #ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { +std::string ToUTF8String(std::wstring_view s) { return g_host->ToUTF8String(s); } -std::wstring ToWideString(const std::string& s) { +std::wstring ToWideString(std::string_view s) { return g_host->ToWideString(s); } #endif // _WIN32 diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 5c9c1a0ae163f..b6c4bccfe4e00 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -1354,6 +1354,11 @@ struct ProviderHost { virtual std::unique_ptr ModelMetadefIdGenerator__construct() = 0; virtual void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) = 0; virtual int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) = 0; + +#ifdef _WIN32 + virtual std::string ToUTF8String(std::wstring_view s) = 0; + virtual std::wstring ToWideString(std::string_view s) = 0; +#endif }; #if defined(_MSC_VER) && !defined(__clang__) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 37f4fe7312bb4..62ce5adc97b00 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1378,9 +1378,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerTypeInfo, _In_ cons return GetNodeDefTypeInfoHelper(sess, get_overridable_initializers_fn, index, out); } -char* onnxruntime::StrDup(const std::string& str, OrtAllocator* allocator) { - char* output_string = reinterpret_cast(allocator->Alloc(allocator, str.size() + 1)); - memcpy(output_string, str.c_str(), str.size()); +char* onnxruntime::StrDup(std::string_view str, OrtAllocator* allocator) { + char* output_string = static_cast(allocator->Alloc(allocator, str.size() + 1)); + memcpy(output_string, str.data(), str.size()); + output_string[str.size()] = '\0'; + return output_string; +} + +wchar_t* onnxruntime::StrDup(std::wstring_view str, OrtAllocator* allocator) { + auto* output_string = static_cast(allocator->Alloc(allocator, str.size() + 1)); + memcpy(output_string, str.data(), str.size() * sizeof(wchar_t)); output_string[str.size()] = '\0'; return output_string; } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 01b70db6d940e..bd50cdfe6e066 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1730,6 +1730,8 @@ struct ProviderHostImpl : ProviderHost { #ifdef _WIN32 std::string ToUTF8String(const std::wstring& s) override { return onnxruntime::ToUTF8String(s); } std::wstring ToWideString(const std::string& s) override { return onnxruntime::ToWideString(s); } + std::string ToUTF8String(std::wstring_view s) override { return onnxruntime::ToUTF8String(s); } + std::wstring ToWideString(std::string_view s) override { return onnxruntime::ToWideString(s); } #endif ProviderHostCPU& GetProviderHostCPU() override { return onnxruntime::GetProviderHostCPU(); } From e4d1c3ea21337c47cb4f5e0cd931f0fcbd00664f Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 29 Jul 2025 01:16:57 +0200 Subject: [PATCH 18/46] Support for saving and loading MXR files automaticaly --- .../core/session/onnxruntime_c_api.h | 3 +- .../migraphx/migraphx_execution_provider.cc | 153 +++++++++++------- .../migraphx/migraphx_execution_provider.h | 13 +- .../migraphx_execution_provider_info.cc | 9 +- .../migraphx_execution_provider_info.h | 16 +- .../migraphx_execution_provider_utils.h | 80 +++++++++ .../core/providers/migraphx/migraphx_inc.h | 1 + .../migraphx/migraphx_provider_factory.cc | 17 +- .../migraphx/migraphx_stream_handle.cc | 9 +- .../migraphx/migraphx_stream_handle.h | 6 +- .../python/onnxruntime_pybind_state.cc | 50 ++---- onnxruntime/test/util/default_providers.cc | 12 +- 12 files changed, 218 insertions(+), 151 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6eebfa8372492..f78e1c86fd65a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -773,7 +773,8 @@ typedef struct OrtMIGraphXProviderOptions { * \note If a ::OrtArenaCfg has been applied, it will override this field */ int migraphx_arena_extend_strategy; - int migraphx_bf16_enable; // MIGraphX BF16 precision. Default 0 = false, nonzero = true + int migraphx_bf16_enable; // MIGraphX BF16 precision. Default 0 = false, nonzero = true + const ORTCHAR_T* migraphx_cache_dir; // MIGraphX model cache directory } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 765f6710f206f..591caafd9657d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -130,6 +130,7 @@ MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(info_.device_id)); + HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, info.device_id)); t_ = migraphx::target(info.target_device.c_str()); // Quantization @@ -178,10 +179,7 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut } // Save/load migraphx compiled models - save_compiled_model_ = info.save_compiled_model; - save_compiled_path_ = info.save_model_file; - load_compiled_model_ = info.load_compiled_model; - load_compiled_path_ = info.load_model_file; + model_cache_path_ = info.model_cache_dir; exhaustive_tune_ = info.exhaustive_tune; @@ -274,28 +272,11 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } // Save/load migraphx compiled models - const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); - if (!save_comp_model_env.empty()) { - save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_MODEL: " << save_compiled_model_; - } - - const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); - if (save_compiled_model_ && !save_model_path_env.empty()) { - save_compiled_path_ = save_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_PATH: " << save_compiled_path_; - } - - const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); - if (!load_comp_model_env.empty()) { - load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_MODEL: " << load_compiled_model_; - } - - const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); - if (load_compiled_model_ && !load_model_path_env.empty()) { - load_compiled_path_ = load_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_PATH: " << load_compiled_path_; + const auto model_cache_path_env = GetEnvironmentVar(migraphx_env_vars::kModelCachePath); + if (!model_cache_path_env.empty()) { + model_cache_path_ = GetEnvironmentVar(migraphx_env_vars::kModelCachePath); + LOGS_DEFAULT(INFO) << "\n" + << migraphx_env_vars::kModelCachePath << ": " << model_cache_path_; } // dump unsupported ops @@ -314,20 +295,17 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } void MIGraphXExecutionProvider::print_migraphx_ep_flags() { - LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id - << "\n migraphx_fp16_enable: " << fp16_enable_ - << "\n migraphx_bf16_enable: " << bf16_enable_ - << "\n migraphx_fp8_enable: " << fp8_enable_ - << "\n migraphx_int8_enable: " << int8_enable_ + LOGS_DEFAULT(VERBOSE) << "\n " << migraphx_provider_option::kDeviceId << ": " << info_.device_id + << "\n " << migraphx_provider_option::kFp16Enable << ": " << fp16_enable_ + << "\n " << migraphx_provider_option::kBf16Enable << ": " << bf16_enable_ + << "\n " << migraphx_provider_option::kFp8Enable << ": " << fp8_enable_ + << "\n " << migraphx_provider_option::kInt8Enable << ": " << int8_enable_ << "\n dump_model_ops: " << dump_model_ops_ - << "\n exhaustive_tune: " << exhaustive_tune_ - << "\n migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << "\n " << migraphx_provider_option::kExhaustiveTune << ": " << exhaustive_tune_ + << "\n " << migraphx_provider_option::kInt8CalibTable << ": " << int8_calibration_cache_name_ << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ - << "\n use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ - << "\n migraphx_save_compiled_model: " << save_compiled_model_ - << "\n migraphx_save_compiled_model_path: " << save_compiled_path_ - << "\n migraphx_load_compiled_model: " << load_compiled_model_ - << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; + << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_migraphx_calibration_table_ + << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_; } AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, @@ -1285,28 +1263,24 @@ bool get_input_output_names(const GraphViewer& graph, // Attempt to load a model and catch any exceptions on load fail. // Useful to default to EP to trigger the compile if file doesn't exist or loading fails. -bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { - try { - if (load_enable) { - LOGS_DEFAULT(VERBOSE) << "Attempting to load model at:" << path; - prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(VERBOSE) << "load model : Success"; - return true; - } else { - return false; - } - } catch (...) { - return false; +bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path& path) try { + if (!path.empty() && exists(path)) { + LOGS_DEFAULT(VERBOSE) << "Attempting to load model at:" << path.string(); + prog = migraphx::load(path.string().c_str()); + LOGS_DEFAULT(VERBOSE) << "load model : Success"; + return true; } return false; +} catch (...) { + return false; } -void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { - if (save_enable) { - LOGS_DEFAULT(VERBOSE) << "Model Save at " << out_path << ": Begin"; +void save_compiled_model(const migraphx::program& prog, const std::filesystem::path& path) { + if (!path.empty()) { + LOGS_DEFAULT(VERBOSE) << "Model Save at " << path.string() << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); - migraphx::save(prog, out_path.c_str(), fo); + save(prog, path.string().c_str(), fo); LOGS_DEFAULT(VERBOSE) << "Model Save: Complete"; } } @@ -1381,6 +1355,27 @@ void compile_program(migraphx::program& prog, LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; } +std::string to_hex(const uint64_t v) { + std::array s{}; + auto [ptr, _] = std::to_chars(s.data(), s.data() + s.size(), v, 16); + return std::string{s.data(), ptr}; +} + +template +std::string make_hash(T v) { + std::array temp{}; + MurmurHash3::x86_128(v.data(), gsl::narrow_cast(v.size()), temp[0], temp.data()); + return to_hex(temp[0] | static_cast(temp[1]) << 32); +} + +template <> +std::string make_hash(const char* v) { + return make_hash(std::string_view{v}); +} + +constexpr std::uint64_t MIGraphX_Version = + ((MIGRAPHX_VERSION_MAJOR << 16) | (MIGRAPHX_VERSION_MINOR << 8) | MIGRAPHX_VERSION_PATCH); + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1388,6 +1383,33 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& for (const auto& fused_node_graph : fused_nodes) { const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; const Node& fused_node = fused_node_graph.fused_node; + + std::filesystem::path model_cache_file; + auto mxr_filename_prefix = to_hex(MIGraphX_Version) + "-" + GenerateGraphId(graph_body_viewer) + "-" + make_hash(std::string_view(device_prop_.gcnArchName)) + "-"; + + // Get model input names (only first layer) + const Graph* cur_graph = &graph_body_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + const Graph& main_graph = *cur_graph; + const auto& input_tensor = main_graph.GetInputs(); + for (auto i : input_tensor) { + session_input_names.insert(i->Name()); + } + + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + std::vector input_shapes; + for (std::size_t i = 0; i < session_input_names.size(); ++i) { + auto tensor_shape = input_tensor[i]->Shape(); + for (int j = 1; j < tensor_shape->dim_size(); ++j) { + input_shapes.push_back(tensor_shape->dim(j).dim_value()); + } + } + model_cache_file = model_cache_path_ / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + // map parameter input name to index std::unordered_map input_name_index; const auto& input_defs = fused_node.InputDefs(); @@ -1418,7 +1440,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::program prog; if (!no_input_shape) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { + if (!load_precompiled_model(prog, model_cache_file)) { LOGS_DEFAULT(VERBOSE) << "No input shapes detected quantizing model"; #ifndef ENABLE_TRAINING_CORE #ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH @@ -1431,7 +1453,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, bf16_enable_, int8_enable_, fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); compile_program(prog, t_, exhaustive_tune_); - save_compiled_model(prog, save_compiled_model_, save_compiled_path_); + save_compiled_model(prog, model_cache_file); } auto prog_output_shapes = prog.get_output_shapes(); @@ -1454,8 +1476,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, - save_compiled_model_, save_compiled_path_, - load_compiled_model_, load_compiled_path_, dump_model_ops_}; + model_cache_path_.string(), dump_model_ops_}; *state = p.release(); return 0; }; @@ -1465,7 +1486,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& delete static_cast(state); }; - compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* api, OrtKernelContext* context) { Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); @@ -1486,6 +1507,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // from input data bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; + std::vector input_shapes; + if (no_input_shape) { LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { @@ -1525,6 +1548,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& cmp_options.set_input_parameter_shape(name, ort_lens); input_shape_match = false; } + input_shapes.insert(input_shapes.end(), tensor_shape.begin(), tensor_shape.end()); } } } @@ -1533,8 +1557,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // input shapes are different, needs to re-parse onnx and // re-compile the program if (!input_shape_match) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling" << std::endl; + std::filesystem::path model_cache_file; + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + model_cache_file = mgx_state->model_cache_dir / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + if (!load_precompiled_model(prog, model_cache_file)) { + LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling"; #ifndef ENABLE_TRAINING_CORE #ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH cmp_options.set_external_data_path(model_path_.parent_path().string()); @@ -1567,7 +1596,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, fp8_enable, int8_calibration_cache_available, map_dynamic_range); compile_program(prog, t, exhaustive_tune_); - save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); + save_compiled_model(prog, model_cache_file); } mgx_state->prog = prog; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 6e230641aece0..d9a95041ce225 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -35,6 +35,7 @@ constexpr auto kSavedModelPath = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"sv; constexpr auto kLoadCompiledModel = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"sv; constexpr auto kLoadModelPath = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"sv; constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; +constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; } // namespace migraphx_env_vars // Information to construct kernel function state. @@ -55,10 +56,7 @@ struct MIGraphXFuncState { bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; - bool save_compiled_mode = false; - std::string save_compiled_path; - bool load_compiled_mode = false; - std::string load_compiled_path; + std::filesystem::path model_cache_dir; bool dump_model_ops = false; bool exhaustive_tune = false; }; @@ -115,14 +113,13 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool int8_use_native_migraphx_calibration_table_ = false; std::filesystem::path calibration_cache_path_{}; std::unordered_map dynamic_range_map_; - bool save_compiled_model_ = false; - std::string save_compiled_path_; - bool load_compiled_model_ = false; - std::string load_compiled_path_; + std::filesystem::path model_cache_path_{}; + std::set session_input_names; bool dump_model_ops_ = false; migraphx::target t_; std::mutex mgx_mu_; hipStream_t stream_ = nullptr; + hipDeviceProp_t device_prop_; bool exhaustive_tune_ = false; mutable std::filesystem::path model_path_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index fa9e613597a46..5bc2659f09636 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -66,8 +66,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(migraphx_provider_option::kBf16Enable, info.bf16_enable) .AddAssignmentToReference(migraphx_provider_option::kFp8Enable, info.fp8_enable) .AddAssignmentToReference(migraphx_provider_option::kInt8Enable, info.int8_enable) - .AddAssignmentToReference(migraphx_provider_option::kSaveCompiledModel, info.save_compiled_model) - .AddAssignmentToReference(migraphx_provider_option::kLoadCompiledModel, info.load_compiled_model) + .AddAssignmentToReference(migraphx_provider_option::kModelCacheDir, info.model_cache_dir) .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, info.exhaustive_tune) .AddAssignmentToReference(migraphx_provider_option::kMemLimit, info.mem_limit) .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) @@ -86,14 +85,13 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(info.bf16_enable)}, {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(info.fp8_enable)}, {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(info.int8_enable)}, - {std::string{migraphx_provider_option::kSaveCompiledModel}, MakeStringWithClassicLocale(info.save_compiled_model)}, - {std::string{migraphx_provider_option::kLoadCompiledModel}, MakeStringWithClassicLocale(info.load_compiled_model)}, {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(info.mem_limit)}, {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(info.exhaustive_tune)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(info.model_cache_dir)}, }; return options; } @@ -105,11 +103,10 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(info.migraphx_bf16_enable)}, {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, - {std::string{migraphx_provider_option::kSaveCompiledModel}, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, - {std::string{migraphx_provider_option::kLoadCompiledModel}, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(info.migraphx_mem_limit)}, {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, static_cast(info.migraphx_arena_extend_strategy))}, {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(info.migraphx_cache_dir)}, }; return options; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 4a2f4a6521e2c..2b7547cbd3c4e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -35,6 +35,7 @@ constexpr auto kArenaExtendStrategy = "migraphx_arena_extend_strategy"sv; constexpr auto kGpuExternalAlloc = "migraphx_external_alloc"sv; constexpr auto kGpuExternalFree = "migraphx_external_free"sv; constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; +constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; } // namespace migraphx_provider_option // Information needed to construct MIGraphX execution providers. @@ -70,10 +71,7 @@ struct MIGraphXExecutionProviderInfo { bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; - bool save_compiled_model{true}; - std::string save_model_file{"./compiled_model.mxr"}; - bool load_compiled_model{true}; - std::string load_model_file{"./compiled_model.mxr"}; + std::filesystem::path model_cache_dir{}; bool exhaustive_tune{false}; size_t mem_limit{std::numeric_limits::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) @@ -99,13 +97,15 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { (static_cast(info.fp16_enable) << 18) ^ (static_cast(info.int8_enable) << 19) ^ (static_cast(info.int8_use_native_calibration_table) << 20) ^ - (static_cast(info.save_compiled_model) << 21) ^ - (static_cast(info.load_compiled_model) << 22) ^ - (static_cast(info.exhaustive_tune) << 23) ^ - (static_cast(info.bf16_enable) << 24); + (static_cast(info.exhaustive_tune) << 21) ^ + (static_cast(info.bf16_enable) << 22); onnxruntime::HashCombine(data, value); + onnxruntime::HashCombine(info.target_device, value); + onnxruntime::HashCombine(info.default_memory_arena_cfg, value); + onnxruntime::HashCombine(info.int8_calibration_table_name, value); + onnxruntime::HashCombine(info.model_cache_dir, value); onnxruntime::HashCombine(info.mem_limit, value); // Memory pointers diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index a70b848ed27e9..6d239b0dd073c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -17,6 +17,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/execution_provider.h" #include "core/common/path_string.h" +#include "core/framework/murmurhash3.h" namespace fs = std::filesystem; @@ -247,4 +248,83 @@ inline std::filesystem::path GetCachePath(const std::filesystem::path& root, std return root.empty() ? std::filesystem::path{ToPathString(name)} : root / ToPathString(name); } +inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { + HashValue model_hash; + + // find the top level graph + const Graph* cur_graph = &graph_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + + const Graph& main_graph = *cur_graph; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // Use the model's file name instead of the entire path to avoid cache regeneration if a path changes + const fs::path path{main_graph.ModelPath()}; + + if (path.has_filename()) { + const auto model_name = path.filename().string(); + + LOGS_DEFAULT(INFO) << "Model name is '" << model_name << "'"; + // Ensure enough characters are hashed in case model names are too short + const size_t model_name_length = model_name.length(); + constexpr size_t hash_string_length = 500; + std::string repeat_model_name = model_name; + for (size_t i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) { + repeat_model_name += model_name; + } + hash_str(repeat_model_name); + } else { + LOGS_DEFAULT(INFO) << "Model path is empty"; + } + + // fingerprint current graph by hashing graph inputs + for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) { + hash_str(node_arg->Name()); + } + + // hashing outputs, inputs and inputs shapes of each node + const int number_of_ort_nodes = graph_viewer.NumberOfNodes(); + std::vector nodes_vector(number_of_ort_nodes); + std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto& index : nodes_vector) { + const auto& node = graph_viewer.GetNode(node_index[index]); + for (const auto* node_arg : node->OutputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + } + } + for (const auto* node_arg : node->InputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + if (node_arg->Shape() == nullptr) { + continue; + } + int dim_size = node_arg->Shape()->dim_size(); + for (int i = 0; i < dim_size; i++) { + hash_str(std::to_string(node_arg->Shape()->dim(i).dim_value())); + } + } + } + } + +#ifdef __linux__ + hash_str("LINUX"); +#elif defined(_WIN32) + hash_str("WINDOWS"); +#endif + + model_hash = hash[0] | static_cast(hash[1]) << 32; + + std::array s; + auto [ptr, ec] = std::to_chars(s.data(), s.data() + s.size(), model_hash, 16); + return std::string{s.data(), ptr}; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_inc.h b/onnxruntime/core/providers/migraphx/migraphx_inc.h index 2b035b20f619f..49e838747892f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_inc.h +++ b/onnxruntime/core/providers/migraphx/migraphx_inc.h @@ -6,3 +6,4 @@ #include #include #include +#include diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 8c4bbb15d7c89..bb3f78d42815b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -100,15 +100,9 @@ struct MIGraphX_Provider : Provider { info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name; } info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0; - info.save_compiled_model = options.migraphx_save_compiled_model; - info.save_model_file = ""; - if (options.migraphx_save_model_path != nullptr) { - info.save_model_file = options.migraphx_save_model_path; - } - info.load_compiled_model = options.migraphx_load_compiled_model; - info.load_model_file = ""; - if (options.migraphx_load_model_path != nullptr) { - info.load_model_file = options.migraphx_load_model_path; + info.model_cache_dir = ""; + if (options.migraphx_cache_dir != nullptr) { + info.model_cache_dir = options.migraphx_cache_dir; } info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy); info.mem_limit = options.migraphx_mem_limit; @@ -142,10 +136,7 @@ struct MIGraphX_Provider : Provider { migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; - migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model; - migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str(); - migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model; - migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str(); + migx_options.migraphx_cache_dir = internal_options.model_cache_dir.native().c_str(); migx_options.migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); migx_options.migraphx_mem_limit = internal_options.mem_limit; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index cf4ccdc292c17..0baa8a1c67c67 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -11,7 +11,7 @@ namespace onnxruntime { -enum class MIGraphXResource { +enum MIGraphXResource { hip_stream_t = rocm_resource_offset }; @@ -133,11 +133,8 @@ Status MIGraphXStream::CleanUpOnRunEnd() { void* MIGraphXStream::GetResource(int version, int id) const { ORT_ENFORCE(version <= MIGRAPHX_RESOURCE_VERSION, "resource version unsupported!"); - switch (id) { - case MIGraphXResource::hip_stream_t: - return GetHandle(); - default: - break; + if (id == hip_stream_t) { + return GetHandle(); } return nullptr; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h index b25eff1a1b9c6..132ae5fc09d13 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -18,7 +18,7 @@ struct MIGraphXStream : Stream { AllocatorPtr cpu_allocator, bool release_cpu_buffer_on_migraphx_stream); - ~MIGraphXStream(); + ~MIGraphXStream() override; std::unique_ptr CreateNotification(size_t /*num_consumers*/) override; @@ -30,7 +30,7 @@ struct MIGraphXStream : Stream { bool own_stream_{true}; - virtual void* GetResource(int version, int id) const; + void* GetResource(int version, int id) const override; private: std::vector deferred_cpu_buffers_; @@ -39,7 +39,7 @@ struct MIGraphXStream : Stream { }; void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, - const OrtDevice::DeviceType device_type, + OrtDevice::DeviceType device_type, const AllocatorPtr& cpu_allocator, bool release_cpu_buffer_on_migraphx_stream, hipStream_t external_stream, diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index fc640269fa661..affbb3b79d1ef 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -954,8 +954,7 @@ static std::shared_ptr CreateExecutionProviderFactory } else if (type == kMIGraphXExecutionProvider) { #if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE) std::string calibration_table; - std::string save_model_path; - std::string load_model_path; + PathString model_cache_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtMIGraphXProviderOptions params{ @@ -965,13 +964,15 @@ static std::shared_ptr CreateExecutionProviderFactory 0, 0, nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", + 0, + nullptr, + 0, + nullptr, 1, SIZE_MAX, - 0}; + 0, + 0, + nullptr}; for (auto option : it->second) { if (option.first == "device_id") { if (!option.second.empty()) { @@ -1038,39 +1039,10 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == migraphx_provider_option::kSaveCompiledModel) { - if (option.second == "True" || option.second == "true") { - params.migraphx_save_compiled_model = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_save_compiled_model = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == migraphx_provider_option::kSaveModelPath) { - if (!option.second.empty()) { - save_model_path = option.second; - params.migraphx_save_model_path = save_model_path.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_name' should be a " - "file name i.e. 'compiled_model.mxr'.\n"); - } - } else if (option.first == migraphx_provider_option::kLoadCompiledModel) { - if (option.second == "True" || option.second == "true") { - params.migraphx_load_compiled_model = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_load_compiled_model = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == migraphx_provider_option::kLoadModelPath) { + } else if (option.first == migraphx_provider_option::kModelCacheDir) { if (!option.second.empty()) { - load_model_path = option.second; - params.migraphx_load_model_path = load_model_path.c_str(); + model_cache_path = ToPathString(option.second); + params.migraphx_cache_dir = model_cache_path.c_str(); } else { ORT_THROW( "[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a " diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 2e4aa3923b649..eebe425f04c6a 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -87,13 +87,15 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 0, 0, nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", + 0, + nullptr, + 0, + nullptr, 1, SIZE_MAX, - 0}; + 0, + 0, + nullptr}; return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider(); #else return nullptr; From 48989b8e7d44924e93f162fd38f0f95aa1335772 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 29 Jul 2025 16:03:08 +0200 Subject: [PATCH 19/46] The C# interface for MIGraphX execution provider --- .gitignore | 1 + .../NativeMethods.shared.cs | 237 +++++++++++++++++- .../ProviderOptions.shared.cs | 136 ++++++++++ .../SessionOptions.shared.cs | 61 ++++- .../core/session/onnxruntime_c_api.h | 82 ++++++ .../migraphx/migraphx_provider_factory.cc | 24 +- .../migraphx_provider_factory_creator.h | 2 + onnxruntime/core/session/onnxruntime_c_api.cc | 8 +- onnxruntime/core/session/ort_apis.h | 11 + .../core/session/provider_bridge_ort.cc | 133 +++++++++- .../core/session/provider_registration.cc | 64 ++++- setup.py | 1 + tools/ci_build/build.py | 5 + .../nuget/generate_nuspec_for_native_nuget.py | 45 +++- 14 files changed, 789 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index 4d0a1205b7c19..b25763334f227 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ # build, distribute, and bins (+ python proto bindings) +build.*/ build build_*/ .build_debug/* diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 8cca2b42e987a..a518020671621 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -31,10 +31,12 @@ public struct OrtApi public IntPtr CreateStatus; public IntPtr GetErrorCode; public IntPtr GetErrorMessage; + public IntPtr CreateEnv; public IntPtr CreateEnvWithCustomLogger; public IntPtr EnableTelemetryEvents; public IntPtr DisableTelemetryEvents; + public IntPtr CreateSession; public IntPtr CreateSessionFromArray; public IntPtr Run; @@ -70,6 +72,7 @@ public struct OrtApi public IntPtr SessionGetInputName; public IntPtr SessionGetOutputName; public IntPtr SessionGetOverridableInitializerName; + public IntPtr CreateRunOptions; public IntPtr RunOptionsSetRunLogVerbosityLevel; public IntPtr RunOptionsSetRunLogSeverityLevel; @@ -84,8 +87,8 @@ public struct OrtApi public IntPtr CreateTensorWithDataAsOrtValue; public IntPtr IsTensor; public IntPtr GetTensorMutableData; - public IntPtr FillStringTensor; + public IntPtr FillStringTensor; public IntPtr GetStringTensorDataLength; public IntPtr GetStringTensorContent; @@ -139,6 +142,8 @@ public struct OrtApi public IntPtr ReleaseTensorTypeAndShapeInfo; public IntPtr ReleaseSessionOptions; public IntPtr ReleaseCustomOpDomain; + // End of Version 1 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr GetDenotationFromTypeInfo; public IntPtr CastTypeInfoToMapTypeInfo; public IntPtr CastTypeInfoToSequenceTypeInfo; @@ -148,7 +153,6 @@ public struct OrtApi public IntPtr ReleaseMapTypeInfo; public IntPtr ReleaseSequenceTypeInfo; public IntPtr SessionEndProfiling; - public IntPtr SessionGetModelMetadata; public IntPtr ModelMetadataGetProducerName; public IntPtr ModelMetadataGetGraphName; @@ -157,6 +161,7 @@ public struct OrtApi public IntPtr ModelMetadataLookupCustomMetadataMap; public IntPtr ModelMetadataGetVersion; public IntPtr ReleaseModelMetadata; + // End of Version 2 - DO NOT MODIFY ABOVE (see above text for more information) public IntPtr CreateEnvWithGlobalThreadPools; public IntPtr DisablePerSessionThreads; @@ -164,9 +169,12 @@ public struct OrtApi public IntPtr ReleaseThreadingOptions; public IntPtr ModelMetadataGetCustomMetadataMapKeys; public IntPtr AddFreeDimensionOverrideByName; + // End of Version 3 - DO NOT MODIFY ABOVE (see above text for more information) public IntPtr GetAvailableProviders; public IntPtr ReleaseAvailableProviders; + // End of Version 4 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr GetStringTensorElementLength; public IntPtr GetStringTensorElement; public IntPtr FillStringTensorElement; @@ -191,6 +199,8 @@ public struct OrtApi public IntPtr SetGlobalIntraOpNumThreads; public IntPtr SetGlobalInterOpNumThreads; public IntPtr SetGlobalSpinControl; + // End of Version 5 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr AddInitializer; public IntPtr CreateEnvWithCustomLoggerAndGlobalThreadPools; public IntPtr SessionOptionsAppendExecutionProvider_CUDA; @@ -199,10 +209,14 @@ public struct OrtApi public IntPtr SetGlobalDenormalAsZero; public IntPtr CreateArenaCfg; public IntPtr ReleaseArenaCfg; + // End of Version 6 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr ModelMetadataGetGraphDescription; public IntPtr SessionOptionsAppendExecutionProvider_TensorRT; public IntPtr SetCurrentGpuDeviceId; public IntPtr GetCurrentGpuDeviceId; + // End of Version 7 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr KernelInfoGetAttributeArray_float; public IntPtr KernelInfoGetAttributeArray_int64; public IntPtr CreateArenaCfgV2; @@ -211,6 +225,8 @@ public struct OrtApi public IntPtr ReleasePrepackedWeightsContainer; public IntPtr CreateSessionWithPrepackedWeightsContainer; public IntPtr CreateSessionFromArrayWithPrepackedWeightsContainer; + // End of Version 8 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr SessionOptionsAppendExecutionProvider_TensorRT_V2; public IntPtr CreateTensorRTProviderOptions; public IntPtr UpdateTensorRTProviderOptions; @@ -233,6 +249,8 @@ public struct OrtApi public IntPtr GetSparseTensorValues; public IntPtr GetSparseTensorIndicesTypeShape; public IntPtr GetSparseTensorIndices; + // End of Version 9 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr HasValue; public IntPtr KernelContext_GetGPUComputeStream; public IntPtr GetTensorMemoryInfo; @@ -245,12 +263,16 @@ public struct OrtApi public IntPtr SetGlobalCustomJoinThreadFn; public IntPtr SynchronizeBoundInputs; public IntPtr SynchronizeBoundOutputs; + // End of Version 10 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr SessionOptionsAppendExecutionProvider_CUDA_V2; public IntPtr CreateCUDAProviderOptions; public IntPtr UpdateCUDAProviderOptions; public IntPtr GetCUDAProviderOptionsAsString; public IntPtr ReleaseCUDAProviderOptions; public IntPtr SessionOptionsAppendExecutionProvider_MIGraphX; + // End of Version 11 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr AddExternalInitializers; public IntPtr CreateOpAttr; public IntPtr ReleaseOpAttr; @@ -260,6 +282,7 @@ public struct OrtApi public IntPtr SessionOptionsAppendExecutionProvider; public IntPtr CopyKernelInfo; public IntPtr ReleaseKernelInfo; + // End of Version 12 - DO NOT MODIFY ABOVE (see above text for more information) public IntPtr GetTrainingApi; public IntPtr SessionOptionsAppendExecutionProvider_CANN; @@ -267,6 +290,8 @@ public struct OrtApi public IntPtr UpdateCANNProviderOptions; public IntPtr GetCANNProviderOptionsAsString; public IntPtr ReleaseCANNProviderOptions; + // End of Version 13 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr MemoryInfoGetDeviceType; public IntPtr UpdateEnvWithCustomLogLevel; public IntPtr SetGlobalIntraOpThreadAffinity; @@ -281,6 +306,8 @@ public struct OrtApi public IntPtr KernelInfoGetAttribute_tensor; public IntPtr HasSessionConfigEntry; public IntPtr GetSessionConfigEntry; + // End of Version 14 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr SessionOptionsAppendExecutionProvider_Dnnl; public IntPtr CreateDnnlProviderOptions; public IntPtr UpdateDnnlProviderOptions; @@ -297,6 +324,8 @@ public struct OrtApi public IntPtr GetResizedStringTensorElementBuffer; public IntPtr KernelContext_GetAllocator; public IntPtr GetBuildInfoString; + // End of Version 15 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr CreateROCMProviderOptions; public IntPtr UpdateROCMProviderOptions; public IntPtr GetROCMProviderOptionsAsString; @@ -308,6 +337,8 @@ public struct OrtApi public IntPtr UpdateCUDAProviderOptionsWithValue; public IntPtr GetCUDAProviderOptionsByName; public IntPtr KernelContext_GetResource; + // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr SetUserLoggingFunction; public IntPtr ShapeInferContext_GetInputCount; public IntPtr ShapeInferContext_GetInputTypeShape; @@ -318,25 +349,35 @@ public struct OrtApi public IntPtr SetDeterministicCompute; public IntPtr KernelContext_ParallelFor; public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO_V2; + // End of Version 17 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr SessionOptionsAppendExecutionProvider_VitisAI; public IntPtr KernelContext_GetScratchBuffer; public IntPtr KernelInfoGetAllocator; public IntPtr AddExternalInitializersFromFilesInMemory; + // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) + // End of Version 19 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr CreateLoraAdapter; public IntPtr CreateLoraAdapterFromArray; public IntPtr ReleaseLoraAdapter; public IntPtr RunOptionsAddActiveLoraAdapter; + public IntPtr SetEpDynamicOptions; + // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr ReleaseValueInfo; public IntPtr ReleaseNode; public IntPtr ReleaseGraph; public IntPtr ReleaseModel; + public IntPtr GetValueInfoName; public IntPtr GetValueInfoTypeInfo; + public IntPtr GetModelEditorApi; + public IntPtr CreateTensorWithDataAndDeleterAsOrtValue; public IntPtr SessionOptionsSetLoadCancellationFlag; - public IntPtr GetCompileApi; public IntPtr CreateKeyValuePairs; @@ -348,9 +389,7 @@ public struct OrtApi public IntPtr RegisterExecutionProviderLibrary; public IntPtr UnregisterExecutionProviderLibrary; - public IntPtr GetEpDevices; - public IntPtr SessionOptionsAppendExecutionProvider_V2; public IntPtr SessionOptionsSetEpSelectionPolicy; public IntPtr SessionOptionsSetEpSelectionPolicyDelegate; @@ -366,8 +405,95 @@ public struct OrtApi public IntPtr EpDevice_EpMetadata; public IntPtr EpDevice_EpOptions; public IntPtr EpDevice_Device; + public IntPtr GetEpApi; + // End of Version 22 - DO NOT MODIFY ABOVE (see above text for more information) + public IntPtr GetTensorSizeInBytes; + public IntPtr AllocatorGetStats; + + public IntPtr CreateMemoryInfo_V2; + public IntPtr MemoryInfoGetDeviceMemType; + public IntPtr MemoryInfoGetVendorId; + + public IntPtr ValueInfo_GetValueProducer; + public IntPtr ValueInfo_GetValueNumConsumers; + public IntPtr ValueInfo_GetValueConsumers; + public IntPtr ValueInfo_GetInitializerValue; + public IntPtr ValueInfo_GetExternalInitializerInfo; + public IntPtr ValueInfo_IsRequiredGraphInput; + public IntPtr ValueInfo_IsOptionalGraphInput; + public IntPtr ValueInfo_IsGraphOutput; + public IntPtr ValueInfo_IsConstantInitializer; + public IntPtr ValueInfo_IsFromOuterScope; + public IntPtr Graph_GetName; + public IntPtr Graph_GetModelPath; + public IntPtr Graph_GetOnnxIRVersion; + public IntPtr Graph_GetNumOperatorSets; + public IntPtr Graph_GetOperatorSets; + public IntPtr Graph_GetNumInputs; + public IntPtr Graph_GetInputs; + public IntPtr Graph_GetNumOutputs; + public IntPtr Graph_GetOutputs; + public IntPtr Graph_GetNumInitializers; + public IntPtr Graph_GetInitializers; + public IntPtr Graph_GetNumNodes; + public IntPtr Graph_GetNodes; + public IntPtr Graph_GetParentNode; + public IntPtr Graph_GetGraphView; + public IntPtr Node_GetId; + public IntPtr Node_GetName; + public IntPtr Node_GetOperatorType; + public IntPtr Node_GetDomain; + public IntPtr Node_GetSinceVersion; + public IntPtr Node_GetNumInputs; + public IntPtr Node_GetInputs; + public IntPtr Node_GetNumOutputs; + public IntPtr Node_GetOutputs; + public IntPtr Node_GetNumImplicitInputs; + public IntPtr Node_GetImplicitInputs; + public IntPtr Node_GetNumAttributes; + public IntPtr Node_GetAttributes; + public IntPtr Node_GetAttributeByName; + public IntPtr OpAttr_GetType; + public IntPtr OpAttr_GetName; + public IntPtr Node_GetNumSubgraphs; + public IntPtr Node_GetSubgraphs; + public IntPtr Node_GetGraph; + public IntPtr Node_GetEpName; + public IntPtr ReleaseExternalInitializerInfo; + public IntPtr ExternalInitializerInfo_GetFilePath; + public IntPtr ExternalInitializerInfo_GetFileOffset; + public IntPtr ExternalInitializerInfo_GetByteSize; + + public IntPtr GetRunConfigEntry; + + public IntPtr EpDevice_MemoryInfo; + + public IntPtr CreateSharedAllocator; + public IntPtr GetSharedAllocator; + public IntPtr ReleaseSharedAllocator; + + public IntPtr GetTensorData; + + public IntPtr GetSessionOptionsConfigEntries; + + public IntPtr SessionGetMemoryInfoForInputs; + public IntPtr SessionGetMemoryInfoForOutputs; + public IntPtr SessionGetEpDeviceForInputs; + + public IntPtr CreateSyncStreamForEpDevice; + public IntPtr SyncStream_GetHandle; + public IntPtr ReleaseSyncStream; + + public IntPtr CopyTensors; + + public IntPtr CreateMIGraphXProviderOptions; + public IntPtr UpdateMIGraphXProviderOptions; + public IntPtr GetMIGraphXProviderOptionsAsString; + public IntPtr ReleaseMIGraphXProviderOptions; + public IntPtr UpdateMIGraphXProviderOptionsWithValue; + public IntPtr GetMIGraphXProviderOptionsByName; } internal static class NativeMethods @@ -611,6 +737,18 @@ static NativeMethods() OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions)); OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString)); OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions)); + SessionOptionsAppendExecutionProvider_MIGraphX = (DSessionOptionsAppendExecutionProvider_MIGraphX)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsAppendExecutionProvider_MIGraphX, typeof(DSessionOptionsAppendExecutionProvider_MIGraphX)); + OrtCreateMIGraphXProviderOptions = (DOrtCreateMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateMIGraphXProviderOptions, typeof(DOrtCreateMIGraphXProviderOptions)); + OrtUpdateMIGraphXProviderOptions = (DOrtUpdateMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateMIGraphXProviderOptions, typeof(DOrtUpdateMIGraphXProviderOptions)); + OrtGetMIGraphXProviderOptionsAsString = (DOrtGetMIGraphXProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetMIGraphXProviderOptionsAsString, typeof(DOrtGetMIGraphXProviderOptionsAsString)); + OrtReleaseMIGraphXProviderOptions = (DOrtReleaseMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseMIGraphXProviderOptions, typeof(DOrtReleaseMIGraphXProviderOptions)); + OrtUpdateMIGraphXProviderOptionsWithValue = + (DOrtUpdateMIGraphXProviderOptionsWithValue)Marshal.GetDelegateForFunctionPointer( + api_.UpdateMIGraphXProviderOptionsWithValue, typeof(DOrtUpdateMIGraphXProviderOptionsWithValue)); + OrtGetMIGraphXProviderOptionsByName = + (DOrtGetMIGraphXProviderOptionsByName)Marshal.GetDelegateForFunctionPointer( + api_.GetMIGraphXProviderOptionsByName, typeof(DOrtGetMIGraphXProviderOptionsByName)); OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2)); OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync)); CreateLoraAdapter = (DCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter, @@ -921,6 +1059,80 @@ internal class NativeLib #endregion +#region Provider Options API + /// + /// Creates native OrtMIGraphXProviderOptions instance + /// + /// (output) native instance of OrtMIGraphXProviderOptions + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateMIGraphXProviderOptions( + out IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance); + public static DOrtCreateMIGraphXProviderOptions OrtCreateMIGraphXProviderOptions; + + /// + /// Updates native OrtMIGraphXProviderOptions instance using given key/value pairs + /// + /// native instance of OrtMIGraphXProviderOptions + /// configuration keys of OrtMIGraphXProviderOptions + /// configuration values of OrtMIGraphXProviderOptions + /// number of configuration keys + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtUpdateMIGraphXProviderOptions( + IntPtr /*(OrtMIGraphXProviderOptions*)*/ migraphxProviderOptionsInstance, + IntPtr[] /*(const char* const *)*/ providerOptionsKeys, + IntPtr[] /*(const char* const *)*/ providerOptionsValues, + UIntPtr /*(size_t)*/ numKeys); + public static DOrtUpdateMIGraphXProviderOptions OrtUpdateMIGraphXProviderOptions; + + /// + /// Get native OrtMIGraphXProviderOptions in serialized string + /// + /// instance of OrtAllocator + /// is a UTF-8 null terminated string allocated using 'allocator' + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetMIGraphXProviderOptionsAsString( + IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ ptr); + public static DOrtGetMIGraphXProviderOptionsAsString OrtGetMIGraphXProviderOptionsAsString; + + /// + /// Releases native OrtMIGraphXProviderOptions instance + /// + /// native instance of OrtMIGraphXProviderOptions to be released + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseMIGraphXProviderOptions(IntPtr /*(OrtMIGraphXProviderOptions*)*/ migraphxProviderOptionsInstance); + public static DOrtReleaseMIGraphXProviderOptions OrtReleaseMIGraphXProviderOptions; + + /// + /// Update native OrtMIGraphXProviderOptions with value + /// + /// native instance of OrtMIGraphXProviderOptions to be released + /// configuration key of OrtMIGraphXProviderOptions + /// configuration value of OrtMIGraphXProviderOptions + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtUpdateMIGraphXProviderOptionsWithValue( + IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance, + IntPtr /*(char*)*/ providerOptionsKey, + IntPtr /*(char*)*/ providerOptionsValue); + public static DOrtUpdateMIGraphXProviderOptionsWithValue OrtUpdateMIGraphXProviderOptionsWithValue; + + /// + /// Get native OrtMIGraphXProviderOptions value by name + /// + /// native instance of OrtMIGraphXProviderOptions to be released + /// configuration key of OrtMIGraphXProviderOptions + /// configuration value of OrtMIGraphXProviderOptions to return + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtGetMIGraphXProviderOptionsByName( + IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance, + IntPtr /*(char*)*/ providerOptionsKey, + out IntPtr /*(char**)*/ providerOptionsValue); + public static DOrtGetMIGraphXProviderOptionsByName OrtGetMIGraphXProviderOptionsByName; + + +#endregion + #region Status API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/ status); @@ -1287,6 +1499,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id); + + [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] + public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int use_arena, int device_id); #endif /// /// Append a TensorRT EP instance (configured based on given provider options) to the native OrtSessionOptions instance @@ -1348,6 +1563,18 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DSessionOptionsAppendExecutionProvider_ROCM SessionOptionsAppendExecutionProvider_ROCM; + /// + /// Append a MIGraphX EP instance (configured based on given provider options) to the native OrtSessionOptions instance + /// + /// Native OrtSessionOptions instance + /// Native OrtMIGraphXProviderOptions instance + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_MIGraphX( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtMIGraphXProviderOptions*)*/ migraphxProviderOptions); + + public static DSessionOptionsAppendExecutionProvider_MIGraphX SessionOptionsAppendExecutionProvider_MIGraphX; + /// /// Free Dimension override (by denotation) /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs index 1b9cd7572170b..335b4ef8b3f65 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs @@ -291,6 +291,142 @@ protected override bool ReleaseHandle() } +/// + /// Holds the options for configuring an MIGraphX Execution Provider instance + /// + public class OrtMIGraphXProviderOptions : SafeHandle + { + internal IntPtr Handle + { + get + { + return handle; + } + } + + public int DeviceId + { + get { return _deviceId; } + set + { + UpdateProviderOptionWithValue(_deviceIdPtr, value.ToString()); + _deviceId = value; + } + } + private IntPtr _deviceIdPtr = Marshal.StringToHGlobalAnsi("device_id"); + private int _deviceId = 0; + + public string ModelCacheDir + { + get { return _modelCacheDir; } + set + { + UpdateProviderOptionWithValue(_modelCacheDirPtr, value); + _modelCacheDir = value; + } + } + + private IntPtr _modelCacheDirPtr = Marshal.StringToHGlobalAnsi("migraphx_model_cache_dir"); + private string _modelCacheDir = ""; + + #region Constructor + + /// + /// Constructs an empty OrtMIGraphXProviderOptions instance + /// + public OrtMIGraphXProviderOptions() : base(IntPtr.Zero, true) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMIGraphXProviderOptions(out handle)); + } + + #endregion + + #region Finalizer + + ~OrtMIGraphXProviderOptions() + { + Marshal.FreeHGlobal(_deviceIdPtr); + Marshal.FreeHGlobal(_modelCacheDirPtr); + } + + #endregion + + #region Public Methods + + /// + /// Get MIGraphX EP provider options + /// + /// return C# UTF-16 encoded string + public string GetOptions() + { + var allocator = OrtAllocator.DefaultInstance; + // Process provider options string + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMIGraphXProviderOptionsAsString(handle, + allocator.Pointer, out IntPtr providerOptions)); + return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator); + } + + /// + /// Updates the configuration knobs of OrtMIGraphXProviderOptions that will eventually be used to configure a MIGraphX EP + /// + /// Array of keys to set that correspond with values. + /// Array of values to set that correspond with keys. + /// The number of key/value pairs in the arrays. + private static IntPtr UpdateMIGraphXProviderOptions(IntPtr handle, IntPtr[] keys, IntPtr[] values, UIntPtr count) + { + return NativeMethods.OrtUpdateMIGraphXProviderOptions(handle, keys, values, count); + } + + /// + /// Updates the configuration knobs of OrtMIGraphXProviderOptions that will eventually be used to configure a MIGraphX EP + /// + /// key/value pairs used to configure a MIGraphX Execution Provider + public void UpdateOptions(Dictionary providerOptions) + { + ProviderOptionsUpdater.Update(providerOptions, handle, UpdateMIGraphXProviderOptions); + } + + #endregion + + #region Public Properties + + /// + /// Overrides SafeHandle.IsInvalid + /// + /// returns true if handle is equal to Zero + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + #endregion + + #region Private Methods + + private void UpdateProviderOptionWithValue(IntPtr key, string value) + { + IntPtr valuePtr = Marshal.StringToHGlobalAnsi(value); + var nativeStatus = NativeMethods.OrtUpdateMIGraphXProviderOptionsWithValue(handle, key, valuePtr); + Marshal.FreeHGlobal(valuePtr); + NativeApiStatus.VerifySuccess(nativeStatus); + } + + #endregion + + #region SafeHandle + /// + /// Overrides SafeHandle.ReleaseHandle() to properly dispose of + /// the native instance of OrtMIGraphXProviderOptions + /// + /// always returns true + protected override bool ReleaseHandle() + { + NativeMethods.OrtReleaseMIGraphXProviderOptions(handle); + handle = IntPtr.Zero; + return true; + } + + #endregion + } + + /// /// This helper class contains methods to handle values of provider options /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 6e325f7fe9646..c85cd64efeec0 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -58,6 +58,9 @@ public class SessionOptions : SafeHandle private static string[] cudaDelayLoadedLibs = { }; private static string[] trtDelayLoadedLibs = { }; + // Delay-loaded MIGraphX DLLs. Currently, delayload is disabled. See cmake/CMakeLists.txt for more information. + private static string[] migxDelayLoadedLibs = { }; + #region Constructor and Factory methods /// @@ -205,6 +208,28 @@ public static SessionOptions MakeSessionOptionWithRocmProvider(OrtROCMProviderOp throw; } } + + /// + /// A helper method to construct a SessionOptions object for MIGraaphX execution provider. + /// Use only if MIGraphX is installed and you have the onnxruntime package specific to this Execution Provider. + /// + /// MIGraphX EP provider options + /// A SessionsOptions() object configured for execution on provider options + public static SessionOptions MakeSessionOptionWithMIGraphXProvider(OrtMIGraphXProviderOptions migxProviderOptions) + { + CheckMIGraphXExecutionProviderDLLs(); + SessionOptions options = new SessionOptions(); + try + { + options.AppendExecutionProvider_MIGraphX(migxProviderOptions); + return options; + } + catch (Exception) + { + options.Dispose(); + throw; + } + } #endregion #region ExecutionProviderAppends @@ -347,12 +372,25 @@ public void AppendExecutionProvider_ROCm(OrtROCMProviderOptions rocmProviderOpti public void AppendExecutionProvider_MIGraphX(int deviceId = 0) { #if __MOBILE__ - throw new NotSupportedException($"The MIGraphX Execution Provider is not supported in this build"); + throw new NotSupportedException("The MIGraphX Execution Provider is not supported in this build"); #else NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_MIGraphX(handle, deviceId)); #endif } + /// + /// Use only if you have the onnxruntime package specific to this Execution Provider. + /// + /// device identification + public void AppendExecutionProvider_MIGraphX(OrtMIGraphXProviderOptions migraphxProviderOptions) + { +#if __MOBILE__ + throw new NotSupportedException($"The AMD Nitris Execution Provider is not supported in this build"); +#else + NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_MIGraphX(handle, migraphxProviderOptions.Handle)); +#endif + } + /// /// Use only if you have the onnxruntime package specific to this Execution Provider. /// @@ -1125,6 +1163,27 @@ private static bool CheckRocmExecutionProviderDLLs() return true; } + private static bool CheckMIGraphXExecutionProviderDLLs() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + foreach (var dll in migxDelayLoadedLibs) + { + IntPtr handle = LoadLibrary(dll); + if (handle != IntPtr.Zero) + continue; + var sysdir = new StringBuilder(String.Empty, 2048); + GetSystemDirectory(sysdir, (uint)sysdir.Capacity); + throw new OnnxRuntimeException( + ErrorCode.NoSuchFile, + $"kernel32.LoadLibrary():'{dll}' not found. MIGraphX are required for GPU execution. " + + $". Verify it is available in the system directory={sysdir}. Else copy it to the output folder." + ); + } + } + return true; + } + #endregion #region SafeHandle diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index f78e1c86fd65a..5e1770fe47bbd 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6440,6 +6440,88 @@ struct OrtApi { _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); + + /// @} + /// \name OrtMIGraphXProviderOptions + /// @{ + + /** \brief Create an OrtMIGraphXProviderOptions + * + * \param[out] out Newly created ::OrtMIGraphXProviderOptions. Must be released with OrtApi::ReleaseMIGraphXProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out); + + /** \brief Set options in a MIGraphX Execution Provider. + * + * For example, key="device_id" and value="0" + * + * \param[in] migraphx_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(UpdateMIGraphXProviderOptions, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** + * Get serialized MIGraphX provider options string. + * + * For example, "device_id=0;;......" + * + * \param migraphx_options - OrtMIGraphXProviderOptions instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(GetMIGraphXProviderOptionsAsString, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtMIGraphXProviderOptions + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + * + * \since Version 1.xx. + */ + void(ORT_API_CALL* ReleaseMIGraphXProviderOptions)(_Frees_ptr_opt_ OrtMIGraphXProviderOptions* input); + + /** + * Update MIGraphX EP provider option where its data type is pointer, for example 'user_compute_stream'. + * If the data type of the provider option can be represented by string please use UpdateMIGraphXProviderOptions. + * + * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. + * + * \param migraphx_options - OrtMIGraphXProviderOptions instance + * \param key - Name of the provider option + * \param value - A pointer to the instance that will be assigned to this provider option + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(UpdateMIGraphXProviderOptionsWithValue, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _In_ void* value); + + /** + * Get MIGraphX EP provider option where its data type is pointer. + * If the data type of the provider option can be represented by string please use GetMIGraphXProviderOptionsAsString. + * + * \param migraphx_options - OrtMIGraphXProviderOptions instance + * \param key - Name of the provider option + * \param ptr - A pointer to the instance that is kept by the provider option + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(GetMIGraphXProviderOptionsByName, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _Outptr_ void** ptr); }; /* diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index bb3f78d42815b..9566e4735bdb4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -119,24 +119,36 @@ struct MIGraphX_Provider : Provider { migx_options.migraphx_int8_enable = internal_options.int8_enable; migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; - char* dest = nullptr; - auto str_size = internal_options.int8_calibration_table_name.size(); - if (str_size == 0) { + if (internal_options.int8_calibration_table_name.empty()) { migx_options.migraphx_int8_calibration_table_name = nullptr; } else { - dest = new char[str_size + 1]; + auto str_size = internal_options.int8_calibration_table_name.size(); + auto dest = new char[str_size + 1]; #ifdef _MSC_VER strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size); #else strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); #endif dest[str_size] = '\0'; - migx_options.migraphx_int8_calibration_table_name = (const char*)dest; + migx_options.migraphx_int8_calibration_table_name = static_cast(dest); } migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; - migx_options.migraphx_cache_dir = internal_options.model_cache_dir.native().c_str(); + if (internal_options.model_cache_dir.empty()) { + migx_options.migraphx_cache_dir = nullptr; + } else { + const auto cache_dir_str{internal_options.model_cache_dir.native()}; + auto cache_dir = new ORTCHAR_T[cache_dir_str.size() + 1]; +#ifdef _MSC_VER + wcsncpy_s(cache_dir, cache_dir_str.size() + 1, cache_dir_str.data(), cache_dir_str.size()); +#else + strncpy(cache_dir, cache_dir_str.data(), cache_dir_str.size()); +#endif + cache_dir[cache_dir_str.size()] = '\0'; + migx_options.migraphx_cache_dir = cache_dir; + } + migx_options.migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); migx_options.migraphx_mem_limit = internal_options.mem_limit; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h index 02d30ad0f6fbb..db169b9e2f5a9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h @@ -6,6 +6,7 @@ #include #include "core/providers/providers.h" +#include "core/framework/provider_options.h" struct OrtMIGraphXProviderOptions; @@ -14,5 +15,6 @@ namespace onnxruntime { struct MIGraphXProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(const OrtMIGraphXProviderOptions* options); + static std::shared_ptr Create(const ProviderOptions&); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 62ce5adc97b00..2088861cd69a6 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4073,7 +4073,13 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ReleaseSyncStream, &OrtApis::CopyTensors, -}; + + &OrtApis::CreateMIGraphXProviderOptions, + &OrtApis::UpdateMIGraphXProviderOptions, + &OrtApis::GetMIGraphXProviderOptionsAsString, + &OrtApis::ReleaseMIGraphXProviderOptions, + &OrtApis::UpdateMIGraphXProviderOptionsWithValue, + &OrtApis::GetMIGraphXProviderOptionsByName}; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. static_assert(sizeof(OrtApiBase) == sizeof(void*) * 2, "New methods can't be added to OrtApiBase as it is not versioned"); diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index d2f22397bf82c..89cd22ac3c417 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -738,4 +738,15 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); + +ORT_API_STATUS_IMPL(CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out); +ORT_API_STATUS_IMPL(UpdateMIGraphXProviderOptions, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + size_t num_keys); +ORT_API_STATUS_IMPL(GetMIGraphXProviderOptionsAsString, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); +ORT_API(void, ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions*); + +ORT_API_STATUS_IMPL(UpdateMIGraphXProviderOptionsWithValue, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _In_ void* value); +ORT_API_STATUS_IMPL(GetMIGraphXProviderOptionsByName, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _Outptr_ void** ptr); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index bd50cdfe6e066..9b756b72d6c6e 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -6,6 +6,7 @@ #include #include +#include #include "core/common/inlined_containers.h" #include "core/common/path_string.h" @@ -108,6 +109,7 @@ using EtwRegistrationManager_EtwInternalCallback = EtwRegistrationManager::EtwIn #include "core/providers/cann/cann_provider_factory.h" #include "core/providers/dnnl/dnnl_provider_factory.h" #include "core/providers/migraphx/migraphx_provider_factory.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/openvino/openvino_provider_factory.h" #include "core/providers/tensorrt/tensorrt_provider_factory.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" @@ -2110,6 +2112,12 @@ std::shared_ptr NvProviderFactoryCreator::Create( return nullptr; } +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const ProviderOptions& provider_options) { + OrtMIGraphXProviderOptions migraphx_options; + s_library_migraphx.Get().UpdateProviderOptions(&migraphx_options, provider_options); + return s_library_migraphx.Get().CreateExecutionProviderFactory(&migraphx_options); +} + std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options); } @@ -2644,7 +2652,8 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions, #if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) || \ defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || \ defined(USE_CANN) || \ - defined(USE_DNNL) + defined(USE_DNNL) || \ + defined(USE_MIGRAPHX) static std::string BuildOptionsString(const onnxruntime::ProviderOptions::iterator& begin, const onnxruntime::ProviderOptions::iterator& end) { std::ostringstream options; @@ -3079,3 +3088,125 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ return nullptr; API_IMPL_END } + +ORT_API_STATUS_IMPL(OrtApis::CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + auto migraphx_options = std::make_unique(); + memset(migraphx_options.get(), 0, sizeof(OrtMIGraphXProviderOptions)); + *out = migraphx_options.release(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(out); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptions, + _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + size_t num_keys) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + onnxruntime::ProviderOptions provider_options_map; + for (size_t i = 0; i != num_keys; ++i) { + if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || + provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty"); + } + + provider_options_map[provider_options_keys[i]] = provider_options_values[i]; + } + + onnxruntime::s_library_migraphx.Get().UpdateProviderOptions(reinterpret_cast(migraphx_options), provider_options_map); + return nullptr; +#else + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(provider_options_keys); + ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsAsString, + _In_ const OrtMIGraphXProviderOptions* migraphx_options, + _Inout_ OrtAllocator* allocator, + _Outptr_ char** ptr) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + onnxruntime::ProviderOptions options = + onnxruntime::s_library_migraphx.Get().GetProviderOptions(reinterpret_cast(migraphx_options)); + std::string options_str = BuildOptionsString(options.begin(), options.end()); + *ptr = onnxruntime::StrDup(options_str, allocator); + return nullptr; +#else + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(allocator); + ORT_UNUSED_PARAMETER(ptr); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions* ptr) { +#ifdef USE_MIGRAPHX + std::unique_ptr p(ptr); + OrtAllocator* allocator; + GetAllocatorWithDefaultOptions(&allocator); + if (ptr->migraphx_cache_dir != nullptr) { + allocator->Free(allocator, const_cast(ptr->migraphx_cache_dir)); + } +#else + ORT_UNUSED_PARAMETER(ptr); +#endif +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptionsWithValue, + _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_ const char* key, + _In_ void* value) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + auto sv = std::string_view{key}; + OrtAllocator* allocator; + GetAllocatorWithDefaultOptions(&allocator); + if (sv == onnxruntime::migraphx_provider_option::kDeviceId) { + auto dv = std::string_view{static_cast(value)}; + if (std::from_chars(dv.data(), dv.data() + dv.length(), migraphx_options->device_id).ec == std::errc::invalid_argument) { + ORT_THROW("Cannot convert from string to integer - invalid argument"); + } + } else if (sv == onnxruntime::migraphx_provider_option::kModelCacheDir) { + auto sd = std::string_view{static_cast(value)}; + migraphx_options->migraphx_cache_dir = onnxruntime::StrDup(onnxruntime::ToPathString(sd), allocator); + } else { + ORT_THROW("Unsupported provider option name: '" + std::string{sv} + "'"); + } + return nullptr; +#else + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(value); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsByName, + _In_ const OrtMIGraphXProviderOptions* migraphx_options, + _In_ const char* key, + _Outptr_ void** ptr) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + return nullptr; +#else + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(ptr); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 18a463ef69943..a99f873ef0eb9 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -101,6 +101,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, VitisAI, CoreML, NvTensorRtRtx, // TensorRt EP for RTX GPUs. + MIGraphX }; struct EpToAppend { @@ -109,7 +110,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, const char* canonical_name = nullptr; }; - static std::array supported_eps = { + static std::array supported_eps = { EpToAppend{EpID::DML, "DML", kDmlExecutionProvider}, EpToAppend{EpID::QNN, "QNN", kQnnExecutionProvider}, EpToAppend{EpID::OpenVINO, "OpenVINO", kOpenVINOExecutionProvider}, @@ -121,7 +122,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, EpToAppend{EpID::JS, "JS", kJsExecutionProvider}, EpToAppend{EpID::VitisAI, "VitisAI", kVitisAIExecutionProvider}, EpToAppend{EpID::CoreML, "CoreML", kCoreMLExecutionProvider}, - EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}}; + EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}, + EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}}; ProviderOptions provider_options; OrtStatus* status = ParseProviderOptions(provider_options_keys, @@ -279,6 +281,14 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); +#endif + break; + } + case EpID::MIGraphX: { +#if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE) + options->provider_factories.push_back(MIGraphXProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); #endif break; } @@ -617,6 +627,56 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, ORT_UNUSED_PARAMETER(num_keys); return CreateNotEnabledStatus("VitisAI"); } + +ORT_API_STATUS_IMPL(OrtApis::CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out) { + ORT_UNUSED_PARAMETER(out); + return CreateNotEnabledStatus("MIGraphX"); +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptions, + _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + size_t num_keys) { + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(provider_options_keys); + ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateNotEnabledStatus("MIGraphX"); +} + +ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsAsString, + _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, + _Outptr_ char** ptr) { + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(allocator); + ORT_UNUSED_PARAMETER(ptr); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +} + +ORT_API(void, OrtApis::ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions* ptr) { + ORT_UNUSED_PARAMETER(ptr); +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptionsWithValue, + _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_ const char* key, + _In_ void* value) { + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(value); + return CreateNotEnabledStatus("MIGraphX"); +} + +ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsByName, + _In_ const OrtMIGraphXProviderOptions* migraphx_options, + _In_ const char* key, + _Outptr_ void** ptr) { + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(ptr); + return CreateNotEnabledStatus("MIGraphX"); +} #endif ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* provider_options) { diff --git a/setup.py b/setup.py index 5ab1ac5b840d4..e49b31572f156 100644 --- a/setup.py +++ b/setup.py @@ -412,6 +412,7 @@ def finalize_options(self): libs.extend(["onnxruntime_providers_nv_tensorrt_rtx.dll"]) libs.extend(["onnxruntime_providers_openvino.dll"]) libs.extend(["onnxruntime_providers_cuda.dll"]) + libs.extend(["onnxruntime_providers_migraphx.dll"]) libs.extend(["onnxruntime_providers_vitisai.dll"]) libs.extend(["onnxruntime_providers_qnn.dll"]) # DirectML Libs diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 561a76be5fa89..e6a2d0741c859 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1994,6 +1994,7 @@ def build_nuget_package( use_winml, use_qnn, use_dml, + use_migraphx, enable_training_apis, msbuild_extra_options, ): @@ -2031,6 +2032,9 @@ def build_nuget_package( elif use_tensorrt: execution_provider = "/p:ExecutionProvider=tensorrt" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" + elif use_migraphx: + execution_provider = "/p:ExecutionProvider=migraphx" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.MIGraphX" elif use_dnnl: execution_provider = "/p:ExecutionProvider=dnnl" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DNNL" @@ -2622,6 +2626,7 @@ def main(): getattr(args, "use_winml", False), args.use_qnn, getattr(args, "use_dml", False), + args.use_migraphx, args.enable_training_apis, normalize_arg_list(args.msbuild_extra_options), ) diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 211cb7a2a8a75..36173c9489513 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -22,6 +22,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "linux": pkg_name += "-linux-" pkg_name += cpu_arch @@ -31,6 +33,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "osx": pkg_name = "onnxruntime-osx-" + cpu_arch return pkg_name @@ -44,7 +48,11 @@ def get_package_name(os, cpu_arch, ep, is_training_package): def is_this_file_needed(ep, filename, package_name): if package_name == "Microsoft.ML.OnnxRuntime.Gpu": return False - return (ep != "cuda" or "cuda" in filename) and (ep != "tensorrt" or "cuda" not in filename) + return ( + (ep != "cuda" or "cuda" in filename) + and (ep != "tensorrt" or "cuda" not in filename) + and (ep != "migraphx" or "migraphx" not in filename) + ) # nuget_artifacts_dir: the directory with uncompressed C API tarball/zip files @@ -138,7 +146,7 @@ def parse_arguments(): required=False, default="None", type=str, - choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"], + choices=["cuda", "dnnl", "openvino", "migraphx", "tensorrt", "snpe", "qnn", "None"], help="The selected execution provider for this build.", ) parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.") @@ -182,6 +190,8 @@ def generate_description(line_list, package_name): description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA." elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name: description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA." + elif "Microsoft.ML.OnnxRuntime.MIGraphX" in package_name: + description = "This package contains native shared library artifacts for ONNX Runtime with MIGraphX." elif "Intel.ML.OnnxRuntime" in package_name: description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO." elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package @@ -359,6 +369,7 @@ def generate_files(line_list, args): is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning" is_snpe_package = args.package_name == "Microsoft.ML.OnnxRuntime.Snpe" is_qnn_package = args.package_name == "Microsoft.ML.OnnxRuntime.QNN" + is_migraphx_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX" is_training_package = args.package_name in [ "Microsoft.ML.OnnxRuntime.Training", "Microsoft.ML.OnnxRuntime.Training.Gpu", @@ -384,6 +395,7 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", "qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll", + "migraphx_ep_shared_lib": "onnxruntime_providers_migraphx.dll", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -402,6 +414,7 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "libonnxruntime_providers_openvino.so", "cuda_ep_shared_lib": "libonnxruntime_providers_cuda.so", "rocm_ep_shared_lib": "libonnxruntime_providers_rocm.so", + "migraphx_ep_shared_lib": "libonnxruntime_providers_migraphx.so", "onnxruntime_perf_test": "onnxruntime_perf_test", "onnx_test_runner": "onnx_test_runner", } @@ -421,7 +434,7 @@ def generate_files(line_list, args): include_dir = f"{build_dir}\\native\\include" # Sub.Gpu packages do not include the onnxruntime headers - if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu": + if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu" and args.package_name != "Microsoft.ML.OnnxRuntime.MIGraphX": files_list.append( "' ) + if args.execution_provider == "migraphx": + files_list.append( + "' + ) + files_list.append( + "' + ) + if is_dml_package: files_list.append( "' ) - # process all other library dependencies - if is_cpu_package or is_cuda_gpu_package or is_dml_package or is_mklml_package: + # process all other library dependencies + if is_cpu_package or is_cuda_gpu_package or is_migraphx_package or is_dml_package or is_mklml_package: # Process dnnl dependency if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["dnnl"])): files_list.append( @@ -899,6 +932,7 @@ def generate_files(line_list, args): or is_cuda_gpu_linux_sub_package or is_cuda_gpu_win_sub_package or is_rocm_gpu_package + or is_migraphx_package or is_dml_package or is_mklml_package or is_snpe_package @@ -1124,6 +1158,7 @@ def validate_execution_provider(execution_provider): or execution_provider == "tensorrt" or execution_provider == "openvino" or execution_provider == "rocm" + or execution_provider == "migraphx" ): raise Exception( "On Linux platform nuget generation is supported only " From 453ebf5f9fbd3b57029d97297048f2ef47029f7b Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Mon, 28 Jul 2025 19:04:40 +0200 Subject: [PATCH 20/46] Trim the output names of non-digits --- .../migraphx/migraphx_execution_provider.cc | 14 +++++------ .../migraphx_execution_provider_utils.h | 24 +++++++++++++++++++ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 591caafd9657d..e89ab17d739a4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1628,17 +1628,17 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } - // It is a output argument + // It is an output argument else { - auto compute_output_index = [](const std::string& name) -> int { - std::string out_name_prefix = "#output_"; - auto pos = name.find(out_name_prefix); - if (pos == std::string::npos) { + auto compute_output_index = [](const std::string_view sv) -> int { + constexpr std::string_view out_name_prefix = "#output_"; + const auto pos = sv.find(out_name_prefix); + if (pos == std::string_view::npos) { return -1; } - std::string index_str = name.substr(pos + out_name_prefix.length()); - return std::stoi(index_str); + const auto index_str = sv.substr(pos + out_name_prefix.length()); + return ToInteger(Trim(index_str, std::isdigit)); }; int output_index = compute_output_index(name); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 6d239b0dd073c..9cc4564e6fc98 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -327,4 +327,28 @@ inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { return std::string{s.data(), ptr}; } +inline std::string_view TrimLeft(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(0, sv.end() - std::find_if(sv.begin(), sv.end(), [fn](int ch) { + return fn(ch); + })); +} + +inline std::string_view TrimRight(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(sv.end() - std::find_if(sv.rbegin(), sv.rend(), [fn](int ch) { + return fn(ch); + }).base()); +} + +inline std::string_view Trim(std::string_view sv, int (*fn)(int) = std::isspace) { + return TrimRight(TrimLeft(sv, fn), fn); +} + +inline int ToInteger(const std::string_view sv) { + int result = 0; + if (auto [_, ec] = std::from_chars(sv.data(), sv.data() + sv.length(), result); ec == std::errc()) { + return result; + } + ORT_THROW("invalid input for conversion to integer"); +} + } // namespace onnxruntime From 71f272990305181aa24d7531748c46abb09e3551 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 29 Jul 2025 09:50:45 +0200 Subject: [PATCH 21/46] Use member variables for EP configuration instead of the info member variable --- .../core/session/onnxruntime_c_api.h | 3 + .../migraphx/migraphx_execution_provider.cc | 36 +++--- .../migraphx/migraphx_execution_provider.h | 30 ++++- .../migraphx_execution_provider_info.cc | 104 ++++++++---------- .../migraphx_execution_provider_info.h | 59 +++++----- .../migraphx_execution_provider_utils.h | 2 +- .../migraphx/migraphx_provider_factory.cc | 62 ++++------- .../migraphx/migraphx_provider_factory.h | 10 +- .../python/onnxruntime_pybind_mlvalue.cc | 3 +- .../python/onnxruntime_pybind_state.cc | 3 + .../python/onnxruntime_pybind_state_common.cc | 6 +- .../python/onnxruntime_pybind_state_common.h | 11 +- onnxruntime/test/util/default_providers.cc | 5 +- .../nuget/generate_nuspec_for_native_nuget.py | 2 +- 14 files changed, 168 insertions(+), 168 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5e1770fe47bbd..99393d7c8d719 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -775,6 +775,9 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_arena_extend_strategy; int migraphx_bf16_enable; // MIGraphX BF16 precision. Default 0 = false, nonzero = true const ORTCHAR_T* migraphx_cache_dir; // MIGraphX model cache directory + void* migraphx_external_alloc; // Pointer to an external Alloc() function (default is none) + void* migraphx_external_free; // Pointer to an external Free() function (default is none) + void* migraphx_external_empty_cache; // Pointer to an external EmptyCache() function (default is none) } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index e89ab17d739a4..20f419db45565 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -117,7 +117,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, info.device_id)}, - info_(info) { + device_id_{info.device_id} { InitProviderOrtApi(); get_flags_from_session_info(info); metadef_id_generator_ = ModelMetadefIdGenerator::Create(); @@ -129,7 +129,7 @@ MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { // Set GPU device to be used - HIP_CALL_THROW(hipSetDevice(info_.device_id)); + HIP_CALL_THROW(hipSetDevice(device_id_)); HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, info.device_id)); t_ = migraphx::target(info.target_device.c_str()); @@ -245,7 +245,7 @@ void MIGraphXExecutionProvider::get_flags_from_env() { const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); if (!cache_path.empty()) { calibration_cache_path_ = cache_path; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_.string(); } const std::string int8_use_native_migraphx_calibration_table_env = @@ -276,7 +276,7 @@ void MIGraphXExecutionProvider::get_flags_from_env() { if (!model_cache_path_env.empty()) { model_cache_path_ = GetEnvironmentVar(migraphx_env_vars::kModelCachePath); LOGS_DEFAULT(INFO) << "\n" - << migraphx_env_vars::kModelCachePath << ": " << model_cache_path_; + << migraphx_env_vars::kModelCachePath << ": " << model_cache_path_.string(); } // dump unsupported ops @@ -295,7 +295,7 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } void MIGraphXExecutionProvider::print_migraphx_ep_flags() { - LOGS_DEFAULT(VERBOSE) << "\n " << migraphx_provider_option::kDeviceId << ": " << info_.device_id + LOGS_DEFAULT(VERBOSE) << "\n " << migraphx_provider_option::kDeviceId << ": " << device_id_ << "\n " << migraphx_provider_option::kFp16Enable << ": " << fp16_enable_ << "\n " << migraphx_provider_option::kBf16Enable << ": " << bf16_enable_ << "\n " << migraphx_provider_option::kFp8Enable << ": " << fp8_enable_ @@ -311,16 +311,14 @@ void MIGraphXExecutionProvider::print_migraphx_ep_flags() { AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, - MIGraphXExecutionProviderExternalAllocatorInfo - external_allocator_info, + void* alloc_fn, + void* free_fn, + void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) { - if (external_allocator_info.UseExternalAllocator()) { - AllocatorCreationInfo default_memory_info( - [external_allocator_info](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP, - external_allocator_info.alloc, - external_allocator_info.free, - external_allocator_info.empty_cache); + if (alloc_fn != nullptr && free_fn != nullptr) { + const AllocatorCreationInfo default_memory_info( + [alloc_fn, free_fn, empty_cache_fn](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP, alloc_fn, free_fn, empty_cache_fn); }, device_id, false); @@ -346,13 +344,15 @@ AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::Devic std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, - info_.device_id); + [](OrtDevice::DeviceId device_id) { + return std::make_unique(device_id, onnxruntime::CUDA); + }, + device_id_); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, onnxruntime::CUDA_PINNED); + return std::make_unique(device_id, CUDA_PINNED); }, - info_.device_id); + device_id_); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index d9a95041ce225..189cee74f264e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -14,6 +14,7 @@ #include #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" +#include "core/framework/provider_options_utils.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/migraphx/migraphx_call.h" @@ -89,21 +90,33 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; - static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, - MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); + static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, ArenaExtendStrategy arena_extend_strategy, + void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* arena_cfg); std::unique_ptr GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; - int GetDeviceId() const override { return info_.device_id; } + int GetDeviceId() const override { return device_id_; } ProviderOptions GetProviderOptions() const override { - return MIGraphXExecutionProviderInfo::ToProviderOptions(info_); + return { + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(device_id_)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(fp16_enable_)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable_)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable_)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable_)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_path_)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit_)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy_)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune_)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc_)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free_)}, + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache_)}}; } private: - MIGraphXExecutionProviderInfo info_; + OrtDevice::DeviceId device_id_{0}; bool fp16_enable_ = false; bool bf16_enable_ = false; bool fp8_enable_ = false; @@ -121,7 +134,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { hipStream_t stream_ = nullptr; hipDeviceProp_t device_prop_; bool exhaustive_tune_ = false; - mutable std::filesystem::path model_path_; + mutable std::filesystem::path model_path_{}; + size_t mem_limit_{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy_{ArenaExtendStrategy::kNextPowerOfTwo}; std::unordered_map map_progs_; std::unordered_map map_onnx_string_; @@ -130,6 +145,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { AllocatorPtr allocator_; std::unique_ptr metadef_id_generator_; + void* external_alloc_{nullptr}; + void* external_free_{nullptr}; + void* external_empty_cache_{nullptr}; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 5bc2659f09636..77a0d8014b678 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -8,7 +8,6 @@ #include "core/common/make_string.h" #include "core/common/parse_string.h" -#include "core/framework/provider_options_utils.h" #include "core/providers/migraphx/migraphx_inc.h" #include "core/providers/migraphx/migraphx_call.h" @@ -19,95 +18,86 @@ const EnumNameMapping arena_extend_strategy_mapping{ {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, }; -MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { - MIGraphXExecutionProviderInfo info{}; - void* alloc = nullptr; - void* free = nullptr; - void* empty_cache = nullptr; +MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptions& options) { ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( migraphx_provider_option::kDeviceId, - [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); + [this](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, device_id)); int num_devices{}; ORT_RETURN_IF_ERROR(HIP_CALL(hipGetDeviceCount(&num_devices))); ORT_RETURN_IF_NOT( - 0 <= info.device_id && info.device_id < num_devices, - "Invalid device ID: ", info.device_id, + 0 <= device_id && device_id < num_devices, + "Invalid device ID: ", device_id, ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); return Status::OK(); }) .AddValueParser( migraphx_provider_option::kGpuExternalAlloc, - [&alloc](const std::string& value_str) -> Status { + [this](const std::string& value_str) -> Status { std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - alloc = reinterpret_cast(address); + external_alloc = reinterpret_cast(address); return Status::OK(); }) .AddValueParser( migraphx_provider_option::kGpuExternalFree, - [&free](const std::string& value_str) -> Status { + [this](const std::string& value_str) -> Status { std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - free = reinterpret_cast(address); + external_free = reinterpret_cast(address); return Status::OK(); }) .AddValueParser( migraphx_provider_option::kGpuExternalEmptyCache, - [&empty_cache](const std::string& value_str) -> Status { + [this](const std::string& value_str) -> Status { std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - empty_cache = reinterpret_cast(address); + external_empty_cache = reinterpret_cast(address); return Status::OK(); }) - .AddAssignmentToReference(migraphx_provider_option::kFp16Enable, info.fp16_enable) - .AddAssignmentToReference(migraphx_provider_option::kBf16Enable, info.bf16_enable) - .AddAssignmentToReference(migraphx_provider_option::kFp8Enable, info.fp8_enable) - .AddAssignmentToReference(migraphx_provider_option::kInt8Enable, info.int8_enable) - .AddAssignmentToReference(migraphx_provider_option::kModelCacheDir, info.model_cache_dir) - .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, info.exhaustive_tune) - .AddAssignmentToReference(migraphx_provider_option::kMemLimit, info.mem_limit) - .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) + .AddAssignmentToReference(migraphx_provider_option::kFp16Enable, fp16_enable) + .AddAssignmentToReference(migraphx_provider_option::kBf16Enable, bf16_enable) + .AddAssignmentToReference(migraphx_provider_option::kFp8Enable, fp8_enable) + .AddAssignmentToReference(migraphx_provider_option::kInt8Enable, int8_enable) + .AddAssignmentToReference(migraphx_provider_option::kModelCacheDir, model_cache_dir) + .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, exhaustive_tune) + .AddAssignmentToReference(migraphx_provider_option::kMemLimit, mem_limit) + .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) .Parse(options)); - - MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; - info.external_allocator_info = alloc_info; - - return info; } -ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXExecutionProviderInfo& info) { - const ProviderOptions options{ - {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(info.device_id)}, - {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(info.fp16_enable)}, - {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(info.bf16_enable)}, - {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(info.fp8_enable)}, - {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(info.int8_enable)}, - {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(info.mem_limit)}, - {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, - {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, - {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, - {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, - {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(info.exhaustive_tune)}, - {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(info.model_cache_dir)}, - }; - return options; +MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const OrtMIGraphXProviderOptions& options) noexcept + : device_id{static_cast(options.device_id)}, + fp16_enable{options.migraphx_fp16_enable != 0}, + bf16_enable{options.migraphx_bf16_enable != 0}, + fp8_enable{options.migraphx_fp8_enable != 0}, + int8_enable{options.migraphx_int8_enable != 0}, + model_cache_dir{options.migraphx_cache_dir}, + exhaustive_tune{options.migraphx_exhaustive_tune != 0}, + mem_limit{options.migraphx_mem_limit}, + arena_extend_strategy{options.migraphx_arena_extend_strategy}, + external_alloc{options.migraphx_external_alloc}, + external_free{options.migraphx_external_free}, + external_empty_cache{options.migraphx_external_empty_cache} { } -ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) { - const ProviderOptions options{ - {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(info.device_id)}, - {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, - {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(info.migraphx_bf16_enable)}, - {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, - {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, - {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(info.migraphx_mem_limit)}, - {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, static_cast(info.migraphx_arena_extend_strategy))}, - {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, - {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(info.migraphx_cache_dir)}, +ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { + return { + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(device_id)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(fp16_enable)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free)}, + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_dir)}, }; - return options; } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 2b7547cbd3c4e..f08201a3aff06 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -11,6 +12,7 @@ #include "core/common/hash_combine.h" #include "core/framework/arena_extend_strategy.h" #include "core/framework/provider_options.h" +#include "core/framework/provider_options_utils.h" #include "core/session/onnxruntime_c_api.h" using namespace std::literals::string_view_literals; @@ -38,57 +40,46 @@ constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; } // namespace migraphx_provider_option -// Information needed to construct MIGraphX execution providers. -struct MIGraphXExecutionProviderExternalAllocatorInfo { - void* alloc{nullptr}; - void* free{nullptr}; - void* empty_cache{nullptr}; - - MIGraphXExecutionProviderExternalAllocatorInfo() { - alloc = nullptr; - free = nullptr; - empty_cache = nullptr; - } - - MIGraphXExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) { - alloc = a; - free = f; - empty_cache = e; - } - - bool UseExternalAllocator() const { - return (alloc != nullptr) && (free != nullptr); - } -}; +extern const EnumNameMapping arena_extend_strategy_mapping; // Information needed to construct trt execution providers. struct MIGraphXExecutionProviderInfo { - std::string target_device; + std::string target_device{"gpu"}; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; bool bf16_enable{false}; bool fp8_enable{false}; bool int8_enable{false}; - std::string int8_calibration_table_name{""}; + std::string int8_calibration_table_name{}; bool int8_use_native_calibration_table{false}; std::filesystem::path model_cache_dir{}; bool exhaustive_tune{false}; - size_t mem_limit{std::numeric_limits::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) - ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) + size_t mem_limit{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; OrtArenaCfg* default_memory_arena_cfg{nullptr}; - MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info{}; - static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); - static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); - static ProviderOptions ToProviderOptions(const OrtMIGraphXProviderOptions& info); + void* external_alloc{nullptr}; + void* external_free{nullptr}; + void* external_empty_cache{nullptr}; + + bool UseExternalAlloc() const { + return external_alloc != nullptr && external_free != nullptr; + } + + MIGraphXExecutionProviderInfo() = default; + + explicit MIGraphXExecutionProviderInfo(const ProviderOptions& options); + explicit MIGraphXExecutionProviderInfo(const OrtMIGraphXProviderOptions& options) noexcept; + ProviderOptions ToProviderOptions() const; }; + } // namespace onnxruntime template <> struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { - size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const { + size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const noexcept { size_t value{0xbc9f1d34}; // seed // Bits: device_id (16), arena_extend_strategy (reserved 2), boolean options (1 each) @@ -109,9 +100,9 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { onnxruntime::HashCombine(info.mem_limit, value); // Memory pointers - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_alloc), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_free), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_empty_cache), value); // The default memory arena cfg is not used in hashing right now. return value; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 9cc4564e6fc98..cce90f3ef82be 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -322,7 +322,7 @@ inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { model_hash = hash[0] | static_cast(hash[1]) << 32; - std::array s; + std::array s{}; auto [ptr, ec] = std::to_chars(s.data(), s.data() + s.size(), model_hash, 16); return std::string{s.data(), ptr}; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 9566e4735bdb4..067e27caf7229 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -17,7 +17,6 @@ #include "core/providers/migraphx/migraphx_provider_factory.h" #include "core/providers/migraphx/migraphx_execution_provider.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" -#include "core/providers/migraphx/migraphx_provider_factory_creator.h" #include "core/providers/migraphx/migraphx_allocator.h" #include "core/providers/migraphx/gpu_data_transfer.h" #include "core/framework/provider_options.h" @@ -70,8 +69,9 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } - std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { - return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); + std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, + void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) override { + return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, mem_limit, arena_extend_strategy, alloc_fn, free_fn, empty_cache_fn, default_memory_arena_cfg); } } g_info; @@ -86,7 +86,7 @@ struct MIGraphX_Provider : Provider { } std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); + auto& options = *static_cast(provider_options); MIGraphXExecutionProviderInfo info; info.device_id = static_cast(options.device_id); info.target_device = "gpu"; @@ -110,17 +110,17 @@ struct MIGraphX_Provider : Provider { } void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options); - auto& migx_options = *reinterpret_cast(provider_options); - migx_options.device_id = internal_options.device_id; - migx_options.migraphx_fp16_enable = internal_options.fp16_enable; - migx_options.migraphx_bf16_enable = internal_options.bf16_enable; - migx_options.migraphx_fp8_enable = internal_options.fp8_enable; - migx_options.migraphx_int8_enable = internal_options.int8_enable; - migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; + MIGraphXExecutionProviderInfo internal_options{options}; + const auto migx_options = static_cast(provider_options); + migx_options->device_id = internal_options.device_id; + migx_options->migraphx_fp16_enable = internal_options.fp16_enable; + migx_options->migraphx_bf16_enable = internal_options.bf16_enable; + migx_options->migraphx_fp8_enable = internal_options.fp8_enable; + migx_options->migraphx_int8_enable = internal_options.int8_enable; + migx_options->migraphx_exhaustive_tune = internal_options.exhaustive_tune; if (internal_options.int8_calibration_table_name.empty()) { - migx_options.migraphx_int8_calibration_table_name = nullptr; + migx_options->migraphx_int8_calibration_table_name = nullptr; } else { auto str_size = internal_options.int8_calibration_table_name.size(); auto dest = new char[str_size + 1]; @@ -130,13 +130,13 @@ struct MIGraphX_Provider : Provider { strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); #endif dest[str_size] = '\0'; - migx_options.migraphx_int8_calibration_table_name = static_cast(dest); + migx_options->migraphx_int8_calibration_table_name = static_cast(dest); } - migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; + migx_options->migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; if (internal_options.model_cache_dir.empty()) { - migx_options.migraphx_cache_dir = nullptr; + migx_options->migraphx_cache_dir = nullptr; } else { const auto cache_dir_str{internal_options.model_cache_dir.native()}; auto cache_dir = new ORTCHAR_T[cache_dir_str.size() + 1]; @@ -146,33 +146,19 @@ struct MIGraphX_Provider : Provider { strncpy(cache_dir, cache_dir_str.data(), cache_dir_str.size()); #endif cache_dir[cache_dir_str.size()] = '\0'; - migx_options.migraphx_cache_dir = cache_dir; + migx_options->migraphx_cache_dir = cache_dir; } - migx_options.migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); - migx_options.migraphx_mem_limit = internal_options.mem_limit; - } + migx_options->migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); + migx_options->migraphx_mem_limit = internal_options.mem_limit; - ProviderOptions GetProviderOptions(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); - return onnxruntime::MIGraphXExecutionProviderInfo::ToProviderOptions(options); + migx_options->migraphx_external_alloc = internal_options.external_alloc; + migx_options->migraphx_external_free = internal_options.external_free; + migx_options->migraphx_external_empty_cache = internal_options.external_empty_cache; } - Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata*/, - size_t num_devices, - ProviderOptions& provider_options, - const OrtSessionOptions& session_options, - const OrtLogger& logger, - std::unique_ptr& ep) override { - ORT_UNUSED_PARAMETER(num_devices); - const ConfigOptions* config_options = &session_options.GetConfigOptions(); - - std::array configs_array = {&provider_options, config_options}; - auto ep_factory = CreateExecutionProviderFactory(&provider_options); - ep = ep_factory->CreateProvider(session_options, logger); - - return Status::OK(); + ProviderOptions GetProviderOptions(const void* provider_options) override { + return MIGraphXExecutionProviderInfo{*static_cast(provider_options)}.ToProviderOptions(); } void Initialize() override { diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index 6baee291b7fe5..a19fa7a87fec1 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -5,23 +5,19 @@ #include +#include "core/framework/arena_extend_strategy.h" #include "core/framework/ortdevice.h" -#include "core/session/onnxruntime_c_api.h" namespace onnxruntime { class IAllocator; -class IDataTransfer; -struct IExecutionProviderFactory; -struct MIGraphXExecutionProviderInfo; -enum class ArenaExtendStrategy : int32_t; -struct MIGraphXExecutionProviderExternalAllocatorInfo; struct ProviderInfo_MIGraphX { virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0; virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0; - virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, + ArenaExtendStrategy arena_extend_strategy, void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) = 0; protected: ~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 431fb0f422b81..2f49634323d1a 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -231,7 +231,8 @@ AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id) { if (id_to_allocator_map->find(id) == id_to_allocator_map->end()) { // TODO: Expose knobs so that users can set fields associated with OrtArenaCfg so that we can pass it to the following method - id_to_allocator_map->insert({id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator(id, gpu_mem_limit, arena_extend_strategy, migx_external_allocator_info, nullptr)}); + id_to_allocator_map->insert({id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator(id, gpu_mem_limit, arena_extend_strategy, + migraphx::external::alloc_fn, migraphx::external::free_fn, migraphx::external::empty_cache_fn, nullptr)}); } return (*id_to_allocator_map)[id]; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index affbb3b79d1ef..5fea54a452260 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -972,6 +972,9 @@ static std::shared_ptr CreateExecutionProviderFactory SIZE_MAX, 0, 0, + nullptr, + nullptr, + nullptr, nullptr}; for (auto option : it->second) { if (option.first == "device_id") { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc index 4b9e012764885..cccdb9d23900a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.cc +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -47,7 +47,11 @@ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExten #endif #ifdef USE_MIGRAPHX -onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info{}; +namespace migraphx::external { +void* alloc_fn{nullptr}; +void* free_fn{nullptr}; +void* empty_cache_fn{nullptr}; +} // namespace migraphx::external #endif #if defined(ENABLE_DLPACK) diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index a73b701a36ddb..b4a33e798f942 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -226,9 +226,14 @@ extern onnxruntime::ArenaExtendStrategy arena_extend_strategy; namespace onnxruntime { ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX(); ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX(); -namespace python { -extern onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info; -} // namespace python +namespace python::migraphx::external { +extern void* alloc_fn; +extern void* free_fn; +extern void* empty_cache_fn; +inline bool UseExternalAllocator() { + return alloc_fn != nullptr && free_fn != nullptr; +} +} // namespace python::migraphx::external } // namespace onnxruntime #endif diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index eebe425f04c6a..8df845d7ea5d6 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -80,7 +80,7 @@ std::unique_ptr TensorrtExecutionProviderWithOptions(const O std::unique_ptr DefaultMIGraphXExecutionProvider() { #ifdef USE_MIGRAPHX - OrtMIGraphXProviderOptions params{ + constexpr OrtMIGraphXProviderOptions params{ 0, 0, 0, @@ -95,6 +95,9 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { SIZE_MAX, 0, 0, + nullptr, + nullptr, + nullptr, nullptr}; return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider(); #else diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 36173c9489513..a7d8415c1149c 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -828,7 +828,7 @@ def generate_files(line_list, args): + '\\native" />' ) - # process all other library dependencies + # process all other library dependencies if is_cpu_package or is_cuda_gpu_package or is_migraphx_package or is_dml_package or is_mklml_package: # Process dnnl dependency if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["dnnl"])): From 7d8e4f2c47ed4f541d62af784d46c700ee038bb3 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Wed, 21 May 2025 14:59:04 +0200 Subject: [PATCH 22/46] Bundle MIGraphX and HIP runtime binaries in C# NuGet and Python wheel --- cmake/onnxruntime_providers_migraphx.cmake | 31 +++++++++++ cmake/onnxruntime_python.cmake | 15 ++++++ setup.py | 52 ++++++++++++++++++- .../nuget/generate_nuspec_for_native_nuget.py | 45 ++++++++++++++++ 4 files changed, 141 insertions(+), 2 deletions(-) diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 626ac211d0a6c..8cb5dcf95155a 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -67,6 +67,37 @@ endif() endif() + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + foreach(file migraphx-hiprtc-driver.exe migraphx.dll migraphx_c.dll migraphx_cpu.dll migraphx_device.dll migraphx_gpu.dll migraphx_onnx.dll migraphx_tf.dll) + set(_source "${AMD_MIGRAPHX_HOME}/bin/${file}") + if(EXISTS "${_source}") + add_custom_command(TARGET onnxruntime_providers_migraphx + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $) + set(_target "$/${file}") + list(APPEND _migraphx_targets ${_target}) + endif() + endforeach() + set(MIGRAPHX_LIB_FILES ${_migraphx_targets} CACHE INTERNAL "" FORCE) + install(FILES ${_migraphx_targets} + DESTINATION ${CMAKE_INSTALL_BINDIR}) + get_property(_amdhip64_location TARGET hip::amdhip64 PROPERTY IMPORTED_LOCATION_RELEASE) + cmake_path(GET _amdhip64_location PARENT_PATH _hipsdk_path) + foreach(file amd_comgr0602.dll amd_comgr0604.dll amd_comgr0700.dll hiprtc0602.dll hiprtc0604.dll hiprtc0700.dll hiprtc-builtins0602.dll hiprtc-builtins0604.dll hiprtc-builtins0700.dll) + set(_source "${_hipsdk_path}/${file}") + if(EXISTS "${_source}") + add_custom_command(TARGET onnxruntime_providers_migraphx + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $) + set(_target "$/${file}") + list(APPEND _hipsdk_targets ${_target}) + endif() + endforeach() + set(HIPSDK_LIB_FILES ${_hipsdk_targets} CACHE INTERNAL "" FORCE) + install(FILES ${_hipsdk_targets} + DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() + install(TARGETS onnxruntime_providers_migraphx EXPORT onnxruntime_providers_migraphxTargets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index c5c85dff96411..ae976abe62fd8 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -740,6 +740,21 @@ if (onnxruntime_USE_OPENVINO) ) endif() +if (onnxruntime_USE_MIGRAPHX) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${MIGRAPHX_LIB_FILES} + $/onnxruntime/capi/) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${HIPSDK_LIB_FILES} + $/onnxruntime/capi/) + endif() +endif() + if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/setup.py b/setup.py index e49b31572f156..404e65c49e54e 100644 --- a/setup.py +++ b/setup.py @@ -285,7 +285,28 @@ def run(self): self._rewrite_ld_preload(to_preload_cann) else: - pass + hipsdk_dependencies = [ + "amd_comgr0602.dll", + "amd_comgr0604.dll", + "amd_comgr0700.dll", + "hiprtc0602.dll", + "hiprtc0604.dll", + "hiprtc0700.dll", + "hiprtc-builtins0602.dll", + "hiprtc-builtins0604.dll", + "hiprtc-builtins0700.dll", + ] + + migraphx_dependencies = [ + "migraphx-hiprtc-driver.exe", + "migraphx.dll", + "migraphx_c.dll", + "migraphx_cpu.dll", + "migraphx_device.dll", + "migraphx_gpu.dll", + "migraphx_onnx.dll", + "migraphx_tf.dll", + ] _bdist_wheel.run(self) if is_manylinux and not disable_auditwheel_repair and not is_openvino and not is_qnn: @@ -293,7 +314,14 @@ def run(self): file = glob(path.join(self.dist_dir, "*linux*.whl"))[0] logger.info("repairing %s for manylinux1", file) auditwheel_cmd = ["auditwheel", "-v", "repair", "-w", self.dist_dir, file] - for i in cuda_dependencies + rocm_dependencies + tensorrt_dependencies + cann_dependencies: + for i in ( + cuda_dependencies + + hipsdk_dependencies + + rocm_dependencies + + migraphx_dependencies + + tensorrt_dependencies + + cann_dependencies + ): auditwheel_cmd += ["--exclude", i] logger.info("Running %s", " ".join([shlex.quote(arg) for arg in auditwheel_cmd])) try: @@ -436,6 +464,26 @@ def finalize_options(self): libs.extend(qnn_deps) if nightly_build: libs.extend(["onnxruntime_pywrapper.dll"]) + migraphx_deps = [ + "amd_comgr0602.dll", + "amd_comgr0604.dll", + "amd_comgr0700.dll", + "hiprtc0602.dll", + "hiprtc0604.dll", + "hiprtc0700.dll", + "hiprtc-builtins0602.dll", + "hiprtc-builtins0604.dll", + "hiprtc-builtins0700.dll", + "migraphx-hiprtc-driver.exe", + "migraphx.dll", + "migraphx_c.dll", + "migraphx_cpu.dll", + "migraphx_device.dll", + "migraphx_gpu.dll", + "migraphx_onnx.dll", + "migraphx_tf.dll", + ] + libs.extend(migraphx_deps) if is_manylinux: if is_openvino: diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index a7d8415c1149c..ead240a7cef1b 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -396,6 +396,23 @@ def generate_files(line_list, args): "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", "qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll", "migraphx_ep_shared_lib": "onnxruntime_providers_migraphx.dll", + "amd_comgr0602": "amd_comgr0602.dll", + "amd_comgr0604": "amd_comgr0604.dll", + "amd_comgr0700": "amd_comgr0700.dll", + "hiprtc0602": "hiprtc0602.dll", + "hiprtc0604": "hiprtc0604.dll", + "hiprtc0700": "hiprtc0700.dll", + "hiprtc-builtins0602": "hiprtc-builtins0602.dll", + "hiprtc-builtins0604": "hiprtc-builtins0604.dll", + "hiprtc-builtins0700": "hiprtc-builtins0700.dll", + "migraphx-hiprtc-driver": "migraphx-hiprtc-driver.exe", + "migraphx": "migraphx.dll", + "migraphx_c": "migraphx_c.dll", + "migraphx_cpu": "migraphx_cpu.dll", + "migraphx_device": "migraphx_device.dll", + "migraphx_gpu": "migraphx_gpu.dll", + "migraphx_onnx": "migraphx_onnx.dll", + "migraphx_tf": "migraphx_tf", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -818,6 +835,34 @@ def generate_files(line_list, args): + '\\native" />' ) + if is_windows_build: + native_build_path = Path(args.native_build_path) + + def _files_list_append(key: str): + path = native_build_path / nuget_dependencies[key] + if path.exists(): + files_list.append( + "' + ) + + _files_list_append("amd_comgr0602") + _files_list_append("amd_comgr0604") + _files_list_append("amd_comgr0700") + _files_list_append("hiprtc0602") + _files_list_append("hiprtc0604") + _files_list_append("hiprtc0700") + _files_list_append("hiprtc-builtins0602") + _files_list_append("hiprtc-builtins0604") + _files_list_append("hiprtc-builtins0700") + _files_list_append("migraphx-hiprtc-driver") + _files_list_append("migraphx") + _files_list_append("migraphx_c") + _files_list_append("migraphx_cpu") + _files_list_append("migraphx_device") + _files_list_append("migraphx_gpu") + _files_list_append("migraphx_onnx") + _files_list_append("migraphx_tf") + if is_dml_package: files_list.append( " Date: Mon, 23 Jun 2025 13:52:01 +0200 Subject: [PATCH 24/46] Remove redundant initialization of execution provider --- .../migraphx/migraphx_execution_provider.cc | 222 ++++++------------ .../migraphx/migraphx_execution_provider.h | 8 +- 2 files changed, 79 insertions(+), 151 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 460abaa06e629..b1a6ee33d1166 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -113,31 +113,78 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c return s_kernel_registry; } +static std::string_view GetArenaExtendStrategyName(ArenaExtendStrategy strategy) { + switch (strategy) { + case ArenaExtendStrategy::kNextPowerOfTwo: + return "kNextPowerOfTwo"; + case ArenaExtendStrategy::kSameAsRequested: + return "kSameAsRequested"; + default: + return "Unknown"; + } +} + +#define GET_ENV(variable, value, ...) \ + const auto value##env{GetEnvironmentVar(variable)}; \ + if (!value##env.empty()) { \ + __VA_ARGS__; \ + LOGS_DEFAULT(INFO) << "\n " << variable << ": " << value##env; \ + } + +#define GET_ENV_BOOL(variable, value) \ + GET_ENV(variable, value, value = std::stoi(value##env) != 0) + +#define GET_ENV_STRING(variable, value) \ + GET_ENV(variable, value, value = value##env) + MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, - info.device_id)}, - device_id_{info.device_id} { + : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, info.device_id)}, + device_id_{info.device_id}, + fp16_enable_{info.fp16_enable}, +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR > 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH >= 2))) + bf16_enable_{info.bf16_enable}, +#endif +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) + fp8_enable_{info.fp8_enable}, +#endif + int8_enable_{info.int8_enable}, + model_cache_path_{info.model_cache_dir}, + t_{info.target_device.c_str()}, + exhaustive_tune_{info.exhaustive_tune}, + metadef_id_generator_{ModelMetadefIdGenerator::Create()}, + external_alloc_{info.external_alloc}, + external_free_{info.external_free}, + external_empty_cache_{info.external_empty_cache} { InitProviderOrtApi(); - get_flags_from_session_info(info); - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); - get_flags_from_env(); -} -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { -} + // Set GPU device to be used and read device properties for feature usage. -void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { - // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(device_id_)); - HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, info.device_id)); - t_ = migraphx::target(info.target_device.c_str()); + HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, device_id_)); - // Quantization - fp16_enable_ = info.fp16_enable; + // Overwrite initialized values with values from environment variables. -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) - bf16_enable_ = info.bf16_enable; + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX ENV Override Variables Set:"; + GET_ENV_BOOL(migraphx_env_vars::kFP16Enable, fp16_enable_); +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR > 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH >= 2))) + GET_ENV_BOOL(migraphx_env_vars::kBF16Enable, bf16_enable_); +#endif +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) + GET_ENV_BOOL(migraphx_env_vars::kFP8Enable, fp8_enable_); +#endif + GET_ENV_BOOL(migraphx_env_vars::kINT8Enable, int8_enable_); + GET_ENV(migraphx_env_vars::kINT8CalibrationTableName, int8_calibration_cache_name_); + GET_ENV(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable, int8_use_native_migraphx_calibration_table_); + GET_ENV_STRING(migraphx_env_vars::kCachePath, calibration_cache_path_); + GET_ENV_STRING(migraphx_env_vars::kModelCachePath, model_cache_path_); + GET_ENV_BOOL(migraphx_env_vars::kDumpModelOps, dump_model_ops_); + GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); + + // Verify configuration correctness and adjust accordingly. + +#if HIP_VERSION_MAJOR < 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR < 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH < 2))) + LOGS_DEFAULT(WARNING) << "MIGraphX: BF16 Quantization requires ROCm 6.4.2 or greater"; + bf16_enable_ = false; #endif if (bf16_enable_ && fp16_enable_) { @@ -146,23 +193,20 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut LOGS_DEFAULT(FATAL) << "MIGraphX: BF16 and FP16 Quantization Mutually exclusive. Ignoring both Quantization flags"; } -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = info.fp8_enable; -#else +#if HIP_VERSION_MAJOR < 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR < 4) LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; fp8_enable_ = false; #endif - int8_enable_ = info.int8_enable; if (int8_enable_ && fp8_enable_) { - int8_enable_ = false; - fp8_enable_ = false; LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } if (int8_enable_ ^ fp8_enable_) { - int8_calibration_cache_name_ = info.int8_calibration_table_name; - int8_use_native_migraphx_calibration_table_ = info.int8_use_native_calibration_table; + int8_calibration_cache_name_ = + int8_calibration_cache_name_env.empty() ? info.int8_calibration_table_name : int8_calibration_cache_name_env; + int8_use_native_migraphx_calibration_table_ = + int8_use_native_migraphx_calibration_table_env.empty() ? info.int8_use_native_calibration_table : std::stoi(int8_use_native_migraphx_calibration_table_env) != 0; } if (int8_enable_ || fp8_enable_) { @@ -170,136 +214,24 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut } // Load INT8 calibration table - std::unordered_map dynamic_range_map; if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); + std::unordered_map dynamic_range_map; + auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path.string()); } } - // Save/load migraphx compiled models - model_cache_path_ = info.model_cache_dir; - - exhaustive_tune_ = info.exhaustive_tune; - - LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider Session Options:"; - print_migraphx_ep_flags(); -} - -void MIGraphXExecutionProvider::get_flags_from_env() { - LOGS_DEFAULT(WARNING) << "\n[MIGraphX EP] MIGraphX ENV Override Variables Set:"; - // whether fp16 is enable - const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); - if (!fp16_enable_env.empty()) { - fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; - } - - const std::string bf16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kBF16Enable); - if (!bf16_enable_env.empty()) { -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) - bf16_enable_ = (std::stoi(bf16_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_BF16_ENABLE: " << fp16_enable_; -#else - LOGS_DEFAULT(WARNING) << "MIGraphX: BF16 Quantization requires ROCm 6.4.2 or greater"; - bf16_enable_ = false; -#endif - } - - if (bf16_enable_ && fp16_enable_) { - LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP16 and BF16 Quantization Mutually exclusive. Ignoring both flags"; - } - - // whether fp8 quantization is enabled - const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); - if (!fp8_enable_env.empty()) { -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; -#else - LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; - fp8_enable = false; -#endif - } - - // whether int8 is enabled - const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); - if (!int8_enable_env.empty()) { - int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; - } - - if (int8_enable_ && fp8_enable_) { - LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; - } - - if (int8_enable_ || fp8_enable_) { - const std::string int8_calibration_cache_name_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); - if (!int8_calibration_cache_name_env.empty()) { - int8_calibration_cache_name_ = int8_calibration_cache_name_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; - } - - const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); - if (!cache_path.empty()) { - calibration_cache_path_ = cache_path; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_.string(); - } - - const std::string int8_use_native_migraphx_calibration_table_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); - if (!int8_use_native_migraphx_calibration_table_env.empty()) { - int8_use_native_migraphx_calibration_table_ = - (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " - << int8_use_native_migraphx_calibration_table_; - } - } - - if (int8_enable_ || fp8_enable_) { - int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); - } - - // Load INT8 calibration table - std::unordered_map dynamic_range_map; - if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path.string()); - } - } - - // Save/load migraphx compiled models - const auto model_cache_path_env = GetEnvironmentVar(migraphx_env_vars::kModelCachePath); - if (!model_cache_path_env.empty()) { - model_cache_path_ = GetEnvironmentVar(migraphx_env_vars::kModelCachePath); - LOGS_DEFAULT(INFO) << "\n" - << migraphx_env_vars::kModelCachePath << ": " << model_cache_path_.string(); - } - - // dump unsupported ops - const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kDumpModelOps); - if (!dump_model_ops_env.empty()) { - dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; - } - - // Allow for exhaustive tune during compile - const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); - if (!exhaustive_tune_env.empty()) { - exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; - } -} + // Print configured options for the session. -void MIGraphXExecutionProvider::print_migraphx_ep_flags() { - LOGS_DEFAULT(VERBOSE) << "\n " << migraphx_provider_option::kDeviceId << ": " << device_id_ + LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider Session Options:" + << "\n " << migraphx_provider_option::kDeviceId << ": " << device_id_ << "\n " << migraphx_provider_option::kFp16Enable << ": " << fp16_enable_ << "\n " << migraphx_provider_option::kBf16Enable << ": " << bf16_enable_ << "\n " << migraphx_provider_option::kFp8Enable << ": " << fp8_enable_ << "\n " << migraphx_provider_option::kInt8Enable << ": " << int8_enable_ + << "\n " << migraphx_provider_option::kMemLimit << ": " << mem_limit_ + << "\n " << migraphx_provider_option::kArenaExtendStrategy << ": " << GetArenaExtendStrategyName(arena_extend_strategy_) << "\n dump_model_ops: " << dump_model_ops_ << "\n " << migraphx_provider_option::kExhaustiveTune << ": " << exhaustive_tune_ << "\n " << migraphx_provider_option::kInt8CalibTable << ": " << int8_calibration_cache_name_ diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 189cee74f264e..927929c60419a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -66,11 +66,7 @@ struct MIGraphXFuncState { class MIGraphXExecutionProvider : public IExecutionProvider { public: explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); - ~MIGraphXExecutionProvider() override; - - void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); - void get_flags_from_env(); - void print_migraphx_ep_flags(); + ~MIGraphXExecutionProvider() override = default; Status Sync() const override; @@ -132,7 +128,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { migraphx::target t_; std::mutex mgx_mu_; hipStream_t stream_ = nullptr; - hipDeviceProp_t device_prop_; + hipDeviceProp_t device_prop_{}; bool exhaustive_tune_ = false; mutable std::filesystem::path model_path_{}; size_t mem_limit_{std::numeric_limits::max()}; From e6143470a26a07bbe86fd610f54154305eeebd5b Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Mon, 28 Jul 2025 21:47:28 +0200 Subject: [PATCH 25/46] Remove redundant allocator static method --- .../migraphx/migraphx_execution_provider.cc | 34 ------------------- .../migraphx/migraphx_execution_provider.h | 3 -- .../migraphx/migraphx_provider_factory.cc | 24 ++++++++++++- 3 files changed, 23 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index b1a6ee33d1166..5e2d611443918 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -240,40 +240,6 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_; } -AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, - size_t migx_mem_limit, - ArenaExtendStrategy arena_extend_strategy, - void* alloc_fn, - void* free_fn, - void* empty_cache_fn, - const OrtArenaCfg* default_memory_arena_cfg) { - if (alloc_fn != nullptr && free_fn != nullptr) { - const AllocatorCreationInfo default_memory_info( - [alloc_fn, free_fn, empty_cache_fn](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP, alloc_fn, free_fn, empty_cache_fn); - }, - device_id, - false); - - return CreateAllocator(default_memory_info); - } else { - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP); - }, - device_id, - true, - {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), - -1, -1, -1, -1L)}, - // make it stream aware - true); - - // ROCM malloc/free is expensive so always use an arena - return CreateAllocator(default_memory_info); - } -} - std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( [](OrtDevice::DeviceId device_id) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 927929c60419a..9d9e1d0e1dd1e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -86,9 +86,6 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; - static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, ArenaExtendStrategy arena_extend_strategy, - void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* arena_cfg); - std::unique_ptr GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 067e27caf7229..6e990522611cf 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -71,7 +71,29 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) override { - return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, mem_limit, arena_extend_strategy, alloc_fn, free_fn, empty_cache_fn, default_memory_arena_cfg); + if (alloc_fn != nullptr && free_fn != nullptr) { + AllocatorCreationInfo default_memory_info{ + [alloc_fn, free_fn, empty_cache_fn](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP, alloc_fn, free_fn, empty_cache_fn); + }, + device_id, false}; + + return CreateAllocator(default_memory_info); + } + AllocatorCreationInfo default_memory_info{ + [](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP); + }, + device_id, + true, + {default_memory_arena_cfg ? *default_memory_arena_cfg + : OrtArenaCfg(mem_limit, static_cast(arena_extend_strategy), + -1, -1, -1, -1L)}, + // make it stream aware + true}; + + // ROCM malloc/free is expensive so always use an arena + return CreateAllocator(default_memory_info); } } g_info; From b2bf6ae02627a892230e2350664fe4a4c79049ae Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 29 Jul 2025 16:07:44 +0200 Subject: [PATCH 26/46] MIGraphX cherry-picks for WCR --- .../migraphx/migraphx_provider_factory.cc | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 6e990522611cf..56ed5f262eeb1 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -223,9 +223,12 @@ struct MigraphXEpFactory : OrtEpFactory { : ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} { GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + GetVendorId = GetVendorIdImpl; + CreateDataTransfer = CreateDataTransferImpl; } // Returns the name for the EP. Each unique factory configuration must have a unique name. @@ -240,6 +243,23 @@ struct MigraphXEpFactory : OrtEpFactory { return factory->vendor.c_str(); } + static const char* GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_version.c_str(); + } + + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id; + } + + static OrtStatus* CreateDataTransferImpl(OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { + ORT_UNUSED_PARAMETER(this_ptr); + *data_transfer = nullptr; // return nullptr to indicate that this EP does not support data transfer. + return nullptr; + } + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. // An EP created with this factory is expected to be able to execute a model with *all* supported // hardware devices at once. A single instance of MigraphX EP is not currently setup to partition a model among @@ -288,7 +308,7 @@ struct MigraphXEpFactory : OrtEpFactory { const OrtLogger& default_logger; const std::string ep_name; const std::string vendor{"AMD"}; - + const std::string ep_version{"0.1.0"}; const uint32_t vendor_id{0x1002}; const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice }; From 71fc0b9974f555b8232a431ddaf7f79c190828e2 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Wed, 2 Jul 2025 13:23:20 +0200 Subject: [PATCH 27/46] Remove ROCm EP dependencies after drop --- tools/ci_build/build.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index e6a2d0741c859..0d51f66df33aa 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -723,8 +723,6 @@ def generate_build_tree( cmake_args += ["-Donnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD=ON"] if args.use_migraphx: cmake_args.append("-Donnxruntime_MIGRAPHX_HOME=" + migraphx_home) - cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) - cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) if args.use_tensorrt: cmake_args.append("-Donnxruntime_TENSORRT_HOME=" + tensorrt_home) From 64eff87910bd5a36d7b3aea7df554aeb3f7862bf Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Mon, 4 Aug 2025 22:27:50 +0200 Subject: [PATCH 28/46] changed back the type to bool --- include/onnxruntime/core/session/onnxruntime_c_api.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 99393d7c8d719..dd5801f0fe101 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -754,11 +754,11 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, nonzero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name - int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true + int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, nonzero = true const char* migraphx_save_model_path; // migraphx model path name - int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true + int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, nonzero = true const char* migraphx_load_model_path; // migraphx model path name - int migraphx_exhaustive_tune; // MIGraphX tuned compile. Default = false, nonzero = true + bool migraphx_exhaustive_tune; // MIGraphX tuned compile. Default = false, nonzero = true /** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t) * Defaults to SIZE_MAX. From e6e5e54dd578d6935e676466932860f3c5f5d979 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 13:40:58 +0200 Subject: [PATCH 29/46] do not modify provider ABI --- .../providers/shared_library/provider_bridge_provider.cc | 4 ++-- .../core/providers/shared_library/provider_interfaces.h | 5 ----- onnxruntime/core/session/provider_bridge_ort.cc | 2 -- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 765701689511b..d690cf31072d2 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -791,11 +791,11 @@ Status LoadDynamicLibrary(onnxruntime::PathString library_name) { #ifdef _WIN32 std::string ToUTF8String(std::wstring_view s) { - return g_host->ToUTF8String(s); + return g_host->ToUTF8String(std::wstring{s}); } std::wstring ToWideString(std::string_view s) { - return g_host->ToWideString(s); + return g_host->ToWideString(std::string{s}); } #endif // _WIN32 } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index b6c4bccfe4e00..5c9c1a0ae163f 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -1354,11 +1354,6 @@ struct ProviderHost { virtual std::unique_ptr ModelMetadefIdGenerator__construct() = 0; virtual void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) = 0; virtual int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) = 0; - -#ifdef _WIN32 - virtual std::string ToUTF8String(std::wstring_view s) = 0; - virtual std::wstring ToWideString(std::string_view s) = 0; -#endif }; #if defined(_MSC_VER) && !defined(__clang__) diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 9b756b72d6c6e..e1160291269dd 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1732,8 +1732,6 @@ struct ProviderHostImpl : ProviderHost { #ifdef _WIN32 std::string ToUTF8String(const std::wstring& s) override { return onnxruntime::ToUTF8String(s); } std::wstring ToWideString(const std::string& s) override { return onnxruntime::ToWideString(s); } - std::string ToUTF8String(std::wstring_view s) override { return onnxruntime::ToUTF8String(s); } - std::wstring ToWideString(std::string_view s) override { return onnxruntime::ToWideString(s); } #endif ProviderHostCPU& GetProviderHostCPU() override { return onnxruntime::GetProviderHostCPU(); } From ed864386a156abc23a5c1a1e87650a5ec7596424 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 13:46:19 +0200 Subject: [PATCH 30/46] use std::string_view variant to reduce binary size --- .../core/framework/provider_options_utils.h | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/include/onnxruntime/core/framework/provider_options_utils.h b/include/onnxruntime/core/framework/provider_options_utils.h index e2c25dde24054..badb7320ea49e 100644 --- a/include/onnxruntime/core/framework/provider_options_utils.h +++ b/include/onnxruntime/core/framework/provider_options_utils.h @@ -83,10 +83,7 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddValueParser( const std::string& name, ValueParserType value_parser) { - ORT_ENFORCE( - value_parsers_.emplace(name, ValueParser{value_parser}).second, - "Provider option \"", name, "\" already has a value parser."); - return *this; + return AddValueParser(std::string_view{name}, value_parser); } template @@ -119,11 +116,7 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddAssignmentToReference( const std::string& name, ValueType& dest) { - return AddValueParser( - name, - [&dest](const std::string& value_str) -> Status { - return ParseStringWithClassicLocale(value_str, dest); - }); + return AddAssignmentToReference(std::string_view{name}, dest); } template @@ -159,11 +152,7 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddAssignmentToEnumReference( const std::string& name, const EnumNameMapping& mapping, EnumType& dest) { - return AddValueParser( - name, - [&mapping, &dest](const std::string& value_str) -> Status { - return NameToEnum(mapping, value_str, dest); - }); + return AddAssignmentToEnumReference(std::string_view{name}, mapping, dest); } template From 90f1ea41ca65dbed217b5e613090d2a9a2355e65 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 13:51:45 +0200 Subject: [PATCH 31/46] review comment: use std::string directly --- onnxruntime/core/providers/shared_library/provider_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 1e4a94a63b749..a7fd83f10fe18 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -330,7 +330,7 @@ inline std::string GetEnvironmentVar(std::string_view var_name) { return GetEnvironmentVar(std::string{var_name}); } inline std::string GetEnvironmentVar(const char* var_name) { - return GetEnvironmentVar(std::string_view{var_name}); + return GetEnvironmentVar(std::string{var_name}); } namespace profiling { From b2bef03450a65bc15024cc2e51b5813b737a9e0e Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 13:53:54 +0200 Subject: [PATCH 32/46] review comment: fix security bug in StrDup(wchar_t) --- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 70eee52252b5b..40bb721985acd 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1386,7 +1386,7 @@ char* onnxruntime::StrDup(std::string_view str, OrtAllocator* allocator) { } wchar_t* onnxruntime::StrDup(std::wstring_view str, OrtAllocator* allocator) { - auto* output_string = static_cast(allocator->Alloc(allocator, str.size() + 1)); + auto* output_string = static_cast(allocator->Alloc(allocator, (str.size() + 1) * sizeof(wchar_t))); memcpy(output_string, str.data(), str.size() * sizeof(wchar_t)); output_string[str.size()] = '\0'; return output_string; From b46d46d35de9653390aa638d30298e85102c207e Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 14:05:01 +0200 Subject: [PATCH 33/46] review comment: CANN update --- onnxruntime/python/onnxruntime_pybind_state.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 5fea54a452260..5991825542865 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1870,9 +1870,12 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra vendor = OrtDevice::VendorIds::NVIDIA; #elif USE_ROCM || USE_MIGRAPHX vendor = OrtDevice::VendorIds::AMD; +#endif + } else if (type == OrtDevice::NPU) { +#if USE_CANN + vendor = OrtDevice::VendorIds::HUAWEI; #endif } - return OrtDevice(type, mem_type, vendor, device_id); }), R"pbdoc(Constructor with vendor_id defaulted to 0 for backward compatibility.)pbdoc") From e06f1eedd3c95888c8b76175e223aadc0f0d3d78 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 14:31:27 +0200 Subject: [PATCH 34/46] review comment: remove unused dependencies --- setup.py | 33 +-------------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/setup.py b/setup.py index 404e65c49e54e..d1feaee2ad9e3 100644 --- a/setup.py +++ b/setup.py @@ -284,44 +284,13 @@ def run(self): self._rewrite_ld_preload_tensorrt(to_preload_nv_tensorrt_rtx) self._rewrite_ld_preload(to_preload_cann) - else: - hipsdk_dependencies = [ - "amd_comgr0602.dll", - "amd_comgr0604.dll", - "amd_comgr0700.dll", - "hiprtc0602.dll", - "hiprtc0604.dll", - "hiprtc0700.dll", - "hiprtc-builtins0602.dll", - "hiprtc-builtins0604.dll", - "hiprtc-builtins0700.dll", - ] - - migraphx_dependencies = [ - "migraphx-hiprtc-driver.exe", - "migraphx.dll", - "migraphx_c.dll", - "migraphx_cpu.dll", - "migraphx_device.dll", - "migraphx_gpu.dll", - "migraphx_onnx.dll", - "migraphx_tf.dll", - ] - _bdist_wheel.run(self) if is_manylinux and not disable_auditwheel_repair and not is_openvino and not is_qnn: assert self.dist_dir is not None file = glob(path.join(self.dist_dir, "*linux*.whl"))[0] logger.info("repairing %s for manylinux1", file) auditwheel_cmd = ["auditwheel", "-v", "repair", "-w", self.dist_dir, file] - for i in ( - cuda_dependencies - + hipsdk_dependencies - + rocm_dependencies - + migraphx_dependencies - + tensorrt_dependencies - + cann_dependencies - ): + for i in cuda_dependencies + rocm_dependencies + tensorrt_dependencies + cann_dependencies: auditwheel_cmd += ["--exclude", i] logger.info("Running %s", " ".join([shlex.quote(arg) for arg in auditwheel_cmd])) try: From d8237537f4320acd68912aa81c65c9cf8f8e7ef2 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 14:36:47 +0200 Subject: [PATCH 35/46] review comment: fix source code formatting --- onnxruntime/python/onnxruntime_pybind_mlvalue.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 2f49634323d1a..893b607b0e18c 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -231,8 +231,11 @@ AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id) { if (id_to_allocator_map->find(id) == id_to_allocator_map->end()) { // TODO: Expose knobs so that users can set fields associated with OrtArenaCfg so that we can pass it to the following method - id_to_allocator_map->insert({id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator(id, gpu_mem_limit, arena_extend_strategy, - migraphx::external::alloc_fn, migraphx::external::free_fn, migraphx::external::empty_cache_fn, nullptr)}); + id_to_allocator_map->insert( + {id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator( + id, gpu_mem_limit, arena_extend_strategy, + migraphx::external::alloc_fn, migraphx::external::free_fn, migraphx::external::empty_cache_fn, + nullptr)}); } return (*id_to_allocator_map)[id]; From 27e60f8ef79ff1088364bdaf4bce6b6d241bcef4 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 14:55:43 +0200 Subject: [PATCH 36/46] review comment: fix security bug --- .../core/providers/migraphx/migraphx_call.cc | 34 ++++++++++--------- .../core/providers/migraphx/migraphx_call.h | 2 +- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc index 61e41ab4c6284..997f318fc640c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -17,24 +17,27 @@ namespace onnxruntime { +namespace { template -const char* RocmErrString(ERRTYPE x) { +std::string_view RocmErrString(ERRTYPE x) { ORT_NOT_IMPLEMENTED(); } #define CASE_ENUM_TO_STR(x) \ - case x: \ - return #x +case x: \ +return #x template <> -const char* RocmErrString(hipError_t x) { +std::string_view RocmErrString(hipError_t x) { (void)hipDeviceSynchronize(); - return hipGetErrorString(x); + return std::string_view{hipGetErrorString(x)}; } +} // namespace + template std::conditional_t RocmCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { + ERRTYPE retCode, std::string_view exprString, std::string_view libName, ERRTYPE successCode, std::string_view msg, std::string_view file, const int line) { if (retCode != successCode) { try { #ifdef _WIN32 @@ -47,17 +50,16 @@ std::conditional_t RocmCall( int currentHipDevice; (void)hipGetDevice(¤tHipDevice); (void)hipGetLastError(); // clear last HIP error - static char str[1024]; - snprintf(str, sizeof(str), "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", - libName, static_cast(retCode), RocmErrString(retCode), currentHipDevice, - hostname.c_str(), - file, line, exprString, msg); + std::stringstream ss; + ss << libName << " failure " << static_cast(retCode) << ": " << RocmErrString(retCode) + << "; GPU=" << currentHipDevice << "; hostname=" << hostname << "; file=" << file << "; line=" << line + << "; expr=" << exprString << "; " << msg; if constexpr (THRW) { // throw an exception with the error info - ORT_THROW(str); + ORT_THROW(ss.str()); } else { - LOGS_DEFAULT(ERROR) << str; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + LOGS_DEFAULT(ERROR) << ss.str(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ss.str()); } } catch (const std::exception& e) { // catch, log, and rethrow since HIP code sometimes hangs in destruction, so we'd never get to see the error if constexpr (THRW) { @@ -73,7 +75,7 @@ std::conditional_t RocmCall( } } -template Status RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); +template Status RocmCall(hipError_t retCode, std::string_view exprString, std::string_view libName, hipError_t successCode, std::string_view msg, std::string_view file, int line); +template void RocmCall(hipError_t retCode, std::string_view exprString, std::string_view libName, hipError_t successCode, std::string_view msg, std::string_view file, int line); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h index 64805784ba75f..9c3b5c79a947b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.h +++ b/onnxruntime/core/providers/migraphx/migraphx_call.h @@ -13,7 +13,7 @@ namespace onnxruntime { template std::conditional_t RocmCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); + ERRTYPE retCode, std::string_view exprString, std::string_view libName, ERRTYPE successCode, std::string_view msg, std::string_view file, int line); #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) From fcd919894d6ef56e719518aae11623ef1ed92a03 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 15:40:20 +0200 Subject: [PATCH 37/46] fix Python wheel name for MIGraphX --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index d1feaee2ad9e3..03e8ef671f149 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ def parse_arg_remove_string(argv, arg_name_equal): is_cuda_version_12 = cuda_version.startswith("12.") elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"): is_migraphx = True + package_name = "onnxruntime-migraphx" elif parse_arg_remove_boolean(sys.argv, "--use_openvino"): is_openvino = True package_name = "onnxruntime-openvino" @@ -90,8 +91,6 @@ def parse_arg_remove_string(argv, arg_name_equal): is_qnn = True package_name = "onnxruntime-qnn" qnn_version = parse_arg_remove_string(sys.argv, "--qnn_version=") -elif is_migraphx: - package_name = "onnxruntime-migraphx" if not nightly_build else "ort-migraphx-nightly" # PEP 513 defined manylinux1_x86_64 and manylinux1_i686 # PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686 From 4b2acf1cb129444d3021accd703d44a64cbcd349 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Tue, 5 Aug 2025 16:44:37 +0200 Subject: [PATCH 38/46] Revert "The C# interface for MIGraphX execution provider" This reverts commit 48989b8e7d44924e93f162fd38f0f95aa1335772. --- .../NativeMethods.shared.cs | 237 +----------------- .../ProviderOptions.shared.cs | 136 ---------- .../SessionOptions.shared.cs | 61 +---- .../core/session/onnxruntime_c_api.h | 82 ------ .../core/providers/migraphx/migraphx_call.cc | 4 +- .../migraphx/migraphx_provider_factory.cc | 1 - onnxruntime/core/session/onnxruntime_c_api.cc | 8 +- onnxruntime/core/session/ort_apis.h | 11 - .../core/session/provider_bridge_ort.cc | 127 +--------- .../core/session/provider_registration.cc | 50 ---- setup.py | 2 + tools/ci_build/build.py | 5 - .../nuget/generate_nuspec_for_native_nuget.py | 88 +------ 13 files changed, 16 insertions(+), 796 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index a518020671621..8cca2b42e987a 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -31,12 +31,10 @@ public struct OrtApi public IntPtr CreateStatus; public IntPtr GetErrorCode; public IntPtr GetErrorMessage; - public IntPtr CreateEnv; public IntPtr CreateEnvWithCustomLogger; public IntPtr EnableTelemetryEvents; public IntPtr DisableTelemetryEvents; - public IntPtr CreateSession; public IntPtr CreateSessionFromArray; public IntPtr Run; @@ -72,7 +70,6 @@ public struct OrtApi public IntPtr SessionGetInputName; public IntPtr SessionGetOutputName; public IntPtr SessionGetOverridableInitializerName; - public IntPtr CreateRunOptions; public IntPtr RunOptionsSetRunLogVerbosityLevel; public IntPtr RunOptionsSetRunLogSeverityLevel; @@ -87,8 +84,8 @@ public struct OrtApi public IntPtr CreateTensorWithDataAsOrtValue; public IntPtr IsTensor; public IntPtr GetTensorMutableData; - public IntPtr FillStringTensor; + public IntPtr GetStringTensorDataLength; public IntPtr GetStringTensorContent; @@ -142,8 +139,6 @@ public struct OrtApi public IntPtr ReleaseTensorTypeAndShapeInfo; public IntPtr ReleaseSessionOptions; public IntPtr ReleaseCustomOpDomain; - // End of Version 1 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr GetDenotationFromTypeInfo; public IntPtr CastTypeInfoToMapTypeInfo; public IntPtr CastTypeInfoToSequenceTypeInfo; @@ -153,6 +148,7 @@ public struct OrtApi public IntPtr ReleaseMapTypeInfo; public IntPtr ReleaseSequenceTypeInfo; public IntPtr SessionEndProfiling; + public IntPtr SessionGetModelMetadata; public IntPtr ModelMetadataGetProducerName; public IntPtr ModelMetadataGetGraphName; @@ -161,7 +157,6 @@ public struct OrtApi public IntPtr ModelMetadataLookupCustomMetadataMap; public IntPtr ModelMetadataGetVersion; public IntPtr ReleaseModelMetadata; - // End of Version 2 - DO NOT MODIFY ABOVE (see above text for more information) public IntPtr CreateEnvWithGlobalThreadPools; public IntPtr DisablePerSessionThreads; @@ -169,12 +164,9 @@ public struct OrtApi public IntPtr ReleaseThreadingOptions; public IntPtr ModelMetadataGetCustomMetadataMapKeys; public IntPtr AddFreeDimensionOverrideByName; - // End of Version 3 - DO NOT MODIFY ABOVE (see above text for more information) public IntPtr GetAvailableProviders; public IntPtr ReleaseAvailableProviders; - // End of Version 4 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr GetStringTensorElementLength; public IntPtr GetStringTensorElement; public IntPtr FillStringTensorElement; @@ -199,8 +191,6 @@ public struct OrtApi public IntPtr SetGlobalIntraOpNumThreads; public IntPtr SetGlobalInterOpNumThreads; public IntPtr SetGlobalSpinControl; - // End of Version 5 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr AddInitializer; public IntPtr CreateEnvWithCustomLoggerAndGlobalThreadPools; public IntPtr SessionOptionsAppendExecutionProvider_CUDA; @@ -209,14 +199,10 @@ public struct OrtApi public IntPtr SetGlobalDenormalAsZero; public IntPtr CreateArenaCfg; public IntPtr ReleaseArenaCfg; - // End of Version 6 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr ModelMetadataGetGraphDescription; public IntPtr SessionOptionsAppendExecutionProvider_TensorRT; public IntPtr SetCurrentGpuDeviceId; public IntPtr GetCurrentGpuDeviceId; - // End of Version 7 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr KernelInfoGetAttributeArray_float; public IntPtr KernelInfoGetAttributeArray_int64; public IntPtr CreateArenaCfgV2; @@ -225,8 +211,6 @@ public struct OrtApi public IntPtr ReleasePrepackedWeightsContainer; public IntPtr CreateSessionWithPrepackedWeightsContainer; public IntPtr CreateSessionFromArrayWithPrepackedWeightsContainer; - // End of Version 8 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr SessionOptionsAppendExecutionProvider_TensorRT_V2; public IntPtr CreateTensorRTProviderOptions; public IntPtr UpdateTensorRTProviderOptions; @@ -249,8 +233,6 @@ public struct OrtApi public IntPtr GetSparseTensorValues; public IntPtr GetSparseTensorIndicesTypeShape; public IntPtr GetSparseTensorIndices; - // End of Version 9 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr HasValue; public IntPtr KernelContext_GetGPUComputeStream; public IntPtr GetTensorMemoryInfo; @@ -263,16 +245,12 @@ public struct OrtApi public IntPtr SetGlobalCustomJoinThreadFn; public IntPtr SynchronizeBoundInputs; public IntPtr SynchronizeBoundOutputs; - // End of Version 10 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr SessionOptionsAppendExecutionProvider_CUDA_V2; public IntPtr CreateCUDAProviderOptions; public IntPtr UpdateCUDAProviderOptions; public IntPtr GetCUDAProviderOptionsAsString; public IntPtr ReleaseCUDAProviderOptions; public IntPtr SessionOptionsAppendExecutionProvider_MIGraphX; - // End of Version 11 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr AddExternalInitializers; public IntPtr CreateOpAttr; public IntPtr ReleaseOpAttr; @@ -282,7 +260,6 @@ public struct OrtApi public IntPtr SessionOptionsAppendExecutionProvider; public IntPtr CopyKernelInfo; public IntPtr ReleaseKernelInfo; - // End of Version 12 - DO NOT MODIFY ABOVE (see above text for more information) public IntPtr GetTrainingApi; public IntPtr SessionOptionsAppendExecutionProvider_CANN; @@ -290,8 +267,6 @@ public struct OrtApi public IntPtr UpdateCANNProviderOptions; public IntPtr GetCANNProviderOptionsAsString; public IntPtr ReleaseCANNProviderOptions; - // End of Version 13 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr MemoryInfoGetDeviceType; public IntPtr UpdateEnvWithCustomLogLevel; public IntPtr SetGlobalIntraOpThreadAffinity; @@ -306,8 +281,6 @@ public struct OrtApi public IntPtr KernelInfoGetAttribute_tensor; public IntPtr HasSessionConfigEntry; public IntPtr GetSessionConfigEntry; - // End of Version 14 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr SessionOptionsAppendExecutionProvider_Dnnl; public IntPtr CreateDnnlProviderOptions; public IntPtr UpdateDnnlProviderOptions; @@ -324,8 +297,6 @@ public struct OrtApi public IntPtr GetResizedStringTensorElementBuffer; public IntPtr KernelContext_GetAllocator; public IntPtr GetBuildInfoString; - // End of Version 15 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr CreateROCMProviderOptions; public IntPtr UpdateROCMProviderOptions; public IntPtr GetROCMProviderOptionsAsString; @@ -337,8 +308,6 @@ public struct OrtApi public IntPtr UpdateCUDAProviderOptionsWithValue; public IntPtr GetCUDAProviderOptionsByName; public IntPtr KernelContext_GetResource; - // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr SetUserLoggingFunction; public IntPtr ShapeInferContext_GetInputCount; public IntPtr ShapeInferContext_GetInputTypeShape; @@ -349,35 +318,25 @@ public struct OrtApi public IntPtr SetDeterministicCompute; public IntPtr KernelContext_ParallelFor; public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO_V2; - // End of Version 17 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr SessionOptionsAppendExecutionProvider_VitisAI; public IntPtr KernelContext_GetScratchBuffer; public IntPtr KernelInfoGetAllocator; public IntPtr AddExternalInitializersFromFilesInMemory; - // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) - // End of Version 19 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr CreateLoraAdapter; public IntPtr CreateLoraAdapterFromArray; public IntPtr ReleaseLoraAdapter; public IntPtr RunOptionsAddActiveLoraAdapter; - public IntPtr SetEpDynamicOptions; - // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr ReleaseValueInfo; public IntPtr ReleaseNode; public IntPtr ReleaseGraph; public IntPtr ReleaseModel; - public IntPtr GetValueInfoName; public IntPtr GetValueInfoTypeInfo; - public IntPtr GetModelEditorApi; - public IntPtr CreateTensorWithDataAndDeleterAsOrtValue; public IntPtr SessionOptionsSetLoadCancellationFlag; + public IntPtr GetCompileApi; public IntPtr CreateKeyValuePairs; @@ -389,7 +348,9 @@ public struct OrtApi public IntPtr RegisterExecutionProviderLibrary; public IntPtr UnregisterExecutionProviderLibrary; + public IntPtr GetEpDevices; + public IntPtr SessionOptionsAppendExecutionProvider_V2; public IntPtr SessionOptionsSetEpSelectionPolicy; public IntPtr SessionOptionsSetEpSelectionPolicyDelegate; @@ -405,95 +366,8 @@ public struct OrtApi public IntPtr EpDevice_EpMetadata; public IntPtr EpDevice_EpOptions; public IntPtr EpDevice_Device; - public IntPtr GetEpApi; - // End of Version 22 - DO NOT MODIFY ABOVE (see above text for more information) - public IntPtr GetTensorSizeInBytes; - public IntPtr AllocatorGetStats; - - public IntPtr CreateMemoryInfo_V2; - public IntPtr MemoryInfoGetDeviceMemType; - public IntPtr MemoryInfoGetVendorId; - - public IntPtr ValueInfo_GetValueProducer; - public IntPtr ValueInfo_GetValueNumConsumers; - public IntPtr ValueInfo_GetValueConsumers; - public IntPtr ValueInfo_GetInitializerValue; - public IntPtr ValueInfo_GetExternalInitializerInfo; - public IntPtr ValueInfo_IsRequiredGraphInput; - public IntPtr ValueInfo_IsOptionalGraphInput; - public IntPtr ValueInfo_IsGraphOutput; - public IntPtr ValueInfo_IsConstantInitializer; - public IntPtr ValueInfo_IsFromOuterScope; - public IntPtr Graph_GetName; - public IntPtr Graph_GetModelPath; - public IntPtr Graph_GetOnnxIRVersion; - public IntPtr Graph_GetNumOperatorSets; - public IntPtr Graph_GetOperatorSets; - public IntPtr Graph_GetNumInputs; - public IntPtr Graph_GetInputs; - public IntPtr Graph_GetNumOutputs; - public IntPtr Graph_GetOutputs; - public IntPtr Graph_GetNumInitializers; - public IntPtr Graph_GetInitializers; - public IntPtr Graph_GetNumNodes; - public IntPtr Graph_GetNodes; - public IntPtr Graph_GetParentNode; - public IntPtr Graph_GetGraphView; - public IntPtr Node_GetId; - public IntPtr Node_GetName; - public IntPtr Node_GetOperatorType; - public IntPtr Node_GetDomain; - public IntPtr Node_GetSinceVersion; - public IntPtr Node_GetNumInputs; - public IntPtr Node_GetInputs; - public IntPtr Node_GetNumOutputs; - public IntPtr Node_GetOutputs; - public IntPtr Node_GetNumImplicitInputs; - public IntPtr Node_GetImplicitInputs; - public IntPtr Node_GetNumAttributes; - public IntPtr Node_GetAttributes; - public IntPtr Node_GetAttributeByName; - public IntPtr OpAttr_GetType; - public IntPtr OpAttr_GetName; - public IntPtr Node_GetNumSubgraphs; - public IntPtr Node_GetSubgraphs; - public IntPtr Node_GetGraph; - public IntPtr Node_GetEpName; - public IntPtr ReleaseExternalInitializerInfo; - public IntPtr ExternalInitializerInfo_GetFilePath; - public IntPtr ExternalInitializerInfo_GetFileOffset; - public IntPtr ExternalInitializerInfo_GetByteSize; - - public IntPtr GetRunConfigEntry; - - public IntPtr EpDevice_MemoryInfo; - - public IntPtr CreateSharedAllocator; - public IntPtr GetSharedAllocator; - public IntPtr ReleaseSharedAllocator; - - public IntPtr GetTensorData; - - public IntPtr GetSessionOptionsConfigEntries; - - public IntPtr SessionGetMemoryInfoForInputs; - public IntPtr SessionGetMemoryInfoForOutputs; - public IntPtr SessionGetEpDeviceForInputs; - - public IntPtr CreateSyncStreamForEpDevice; - public IntPtr SyncStream_GetHandle; - public IntPtr ReleaseSyncStream; - - public IntPtr CopyTensors; - - public IntPtr CreateMIGraphXProviderOptions; - public IntPtr UpdateMIGraphXProviderOptions; - public IntPtr GetMIGraphXProviderOptionsAsString; - public IntPtr ReleaseMIGraphXProviderOptions; - public IntPtr UpdateMIGraphXProviderOptionsWithValue; - public IntPtr GetMIGraphXProviderOptionsByName; } internal static class NativeMethods @@ -737,18 +611,6 @@ static NativeMethods() OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions)); OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString)); OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions)); - SessionOptionsAppendExecutionProvider_MIGraphX = (DSessionOptionsAppendExecutionProvider_MIGraphX)Marshal.GetDelegateForFunctionPointer( - api_.SessionOptionsAppendExecutionProvider_MIGraphX, typeof(DSessionOptionsAppendExecutionProvider_MIGraphX)); - OrtCreateMIGraphXProviderOptions = (DOrtCreateMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateMIGraphXProviderOptions, typeof(DOrtCreateMIGraphXProviderOptions)); - OrtUpdateMIGraphXProviderOptions = (DOrtUpdateMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateMIGraphXProviderOptions, typeof(DOrtUpdateMIGraphXProviderOptions)); - OrtGetMIGraphXProviderOptionsAsString = (DOrtGetMIGraphXProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetMIGraphXProviderOptionsAsString, typeof(DOrtGetMIGraphXProviderOptionsAsString)); - OrtReleaseMIGraphXProviderOptions = (DOrtReleaseMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseMIGraphXProviderOptions, typeof(DOrtReleaseMIGraphXProviderOptions)); - OrtUpdateMIGraphXProviderOptionsWithValue = - (DOrtUpdateMIGraphXProviderOptionsWithValue)Marshal.GetDelegateForFunctionPointer( - api_.UpdateMIGraphXProviderOptionsWithValue, typeof(DOrtUpdateMIGraphXProviderOptionsWithValue)); - OrtGetMIGraphXProviderOptionsByName = - (DOrtGetMIGraphXProviderOptionsByName)Marshal.GetDelegateForFunctionPointer( - api_.GetMIGraphXProviderOptionsByName, typeof(DOrtGetMIGraphXProviderOptionsByName)); OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2)); OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync)); CreateLoraAdapter = (DCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter, @@ -1059,80 +921,6 @@ internal class NativeLib #endregion -#region Provider Options API - /// - /// Creates native OrtMIGraphXProviderOptions instance - /// - /// (output) native instance of OrtMIGraphXProviderOptions - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtCreateMIGraphXProviderOptions( - out IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance); - public static DOrtCreateMIGraphXProviderOptions OrtCreateMIGraphXProviderOptions; - - /// - /// Updates native OrtMIGraphXProviderOptions instance using given key/value pairs - /// - /// native instance of OrtMIGraphXProviderOptions - /// configuration keys of OrtMIGraphXProviderOptions - /// configuration values of OrtMIGraphXProviderOptions - /// number of configuration keys - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtUpdateMIGraphXProviderOptions( - IntPtr /*(OrtMIGraphXProviderOptions*)*/ migraphxProviderOptionsInstance, - IntPtr[] /*(const char* const *)*/ providerOptionsKeys, - IntPtr[] /*(const char* const *)*/ providerOptionsValues, - UIntPtr /*(size_t)*/ numKeys); - public static DOrtUpdateMIGraphXProviderOptions OrtUpdateMIGraphXProviderOptions; - - /// - /// Get native OrtMIGraphXProviderOptions in serialized string - /// - /// instance of OrtAllocator - /// is a UTF-8 null terminated string allocated using 'allocator' - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtGetMIGraphXProviderOptionsAsString( - IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/ ptr); - public static DOrtGetMIGraphXProviderOptionsAsString OrtGetMIGraphXProviderOptionsAsString; - - /// - /// Releases native OrtMIGraphXProviderOptions instance - /// - /// native instance of OrtMIGraphXProviderOptions to be released - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseMIGraphXProviderOptions(IntPtr /*(OrtMIGraphXProviderOptions*)*/ migraphxProviderOptionsInstance); - public static DOrtReleaseMIGraphXProviderOptions OrtReleaseMIGraphXProviderOptions; - - /// - /// Update native OrtMIGraphXProviderOptions with value - /// - /// native instance of OrtMIGraphXProviderOptions to be released - /// configuration key of OrtMIGraphXProviderOptions - /// configuration value of OrtMIGraphXProviderOptions - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr DOrtUpdateMIGraphXProviderOptionsWithValue( - IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance, - IntPtr /*(char*)*/ providerOptionsKey, - IntPtr /*(char*)*/ providerOptionsValue); - public static DOrtUpdateMIGraphXProviderOptionsWithValue OrtUpdateMIGraphXProviderOptionsWithValue; - - /// - /// Get native OrtMIGraphXProviderOptions value by name - /// - /// native instance of OrtMIGraphXProviderOptions to be released - /// configuration key of OrtMIGraphXProviderOptions - /// configuration value of OrtMIGraphXProviderOptions to return - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr DOrtGetMIGraphXProviderOptionsByName( - IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance, - IntPtr /*(char*)*/ providerOptionsKey, - out IntPtr /*(char**)*/ providerOptionsValue); - public static DOrtGetMIGraphXProviderOptionsByName OrtGetMIGraphXProviderOptionsByName; - - -#endregion - #region Status API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/ status); @@ -1499,9 +1287,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id); - - [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] - public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int use_arena, int device_id); #endif /// /// Append a TensorRT EP instance (configured based on given provider options) to the native OrtSessionOptions instance @@ -1563,18 +1348,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DSessionOptionsAppendExecutionProvider_ROCM SessionOptionsAppendExecutionProvider_ROCM; - /// - /// Append a MIGraphX EP instance (configured based on given provider options) to the native OrtSessionOptions instance - /// - /// Native OrtSessionOptions instance - /// Native OrtMIGraphXProviderOptions instance - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_MIGraphX( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtMIGraphXProviderOptions*)*/ migraphxProviderOptions); - - public static DSessionOptionsAppendExecutionProvider_MIGraphX SessionOptionsAppendExecutionProvider_MIGraphX; - /// /// Free Dimension override (by denotation) /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs index 335b4ef8b3f65..1b9cd7572170b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs @@ -291,142 +291,6 @@ protected override bool ReleaseHandle() } -/// - /// Holds the options for configuring an MIGraphX Execution Provider instance - /// - public class OrtMIGraphXProviderOptions : SafeHandle - { - internal IntPtr Handle - { - get - { - return handle; - } - } - - public int DeviceId - { - get { return _deviceId; } - set - { - UpdateProviderOptionWithValue(_deviceIdPtr, value.ToString()); - _deviceId = value; - } - } - private IntPtr _deviceIdPtr = Marshal.StringToHGlobalAnsi("device_id"); - private int _deviceId = 0; - - public string ModelCacheDir - { - get { return _modelCacheDir; } - set - { - UpdateProviderOptionWithValue(_modelCacheDirPtr, value); - _modelCacheDir = value; - } - } - - private IntPtr _modelCacheDirPtr = Marshal.StringToHGlobalAnsi("migraphx_model_cache_dir"); - private string _modelCacheDir = ""; - - #region Constructor - - /// - /// Constructs an empty OrtMIGraphXProviderOptions instance - /// - public OrtMIGraphXProviderOptions() : base(IntPtr.Zero, true) - { - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMIGraphXProviderOptions(out handle)); - } - - #endregion - - #region Finalizer - - ~OrtMIGraphXProviderOptions() - { - Marshal.FreeHGlobal(_deviceIdPtr); - Marshal.FreeHGlobal(_modelCacheDirPtr); - } - - #endregion - - #region Public Methods - - /// - /// Get MIGraphX EP provider options - /// - /// return C# UTF-16 encoded string - public string GetOptions() - { - var allocator = OrtAllocator.DefaultInstance; - // Process provider options string - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMIGraphXProviderOptionsAsString(handle, - allocator.Pointer, out IntPtr providerOptions)); - return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator); - } - - /// - /// Updates the configuration knobs of OrtMIGraphXProviderOptions that will eventually be used to configure a MIGraphX EP - /// - /// Array of keys to set that correspond with values. - /// Array of values to set that correspond with keys. - /// The number of key/value pairs in the arrays. - private static IntPtr UpdateMIGraphXProviderOptions(IntPtr handle, IntPtr[] keys, IntPtr[] values, UIntPtr count) - { - return NativeMethods.OrtUpdateMIGraphXProviderOptions(handle, keys, values, count); - } - - /// - /// Updates the configuration knobs of OrtMIGraphXProviderOptions that will eventually be used to configure a MIGraphX EP - /// - /// key/value pairs used to configure a MIGraphX Execution Provider - public void UpdateOptions(Dictionary providerOptions) - { - ProviderOptionsUpdater.Update(providerOptions, handle, UpdateMIGraphXProviderOptions); - } - - #endregion - - #region Public Properties - - /// - /// Overrides SafeHandle.IsInvalid - /// - /// returns true if handle is equal to Zero - public override bool IsInvalid { get { return handle == IntPtr.Zero; } } - - #endregion - - #region Private Methods - - private void UpdateProviderOptionWithValue(IntPtr key, string value) - { - IntPtr valuePtr = Marshal.StringToHGlobalAnsi(value); - var nativeStatus = NativeMethods.OrtUpdateMIGraphXProviderOptionsWithValue(handle, key, valuePtr); - Marshal.FreeHGlobal(valuePtr); - NativeApiStatus.VerifySuccess(nativeStatus); - } - - #endregion - - #region SafeHandle - /// - /// Overrides SafeHandle.ReleaseHandle() to properly dispose of - /// the native instance of OrtMIGraphXProviderOptions - /// - /// always returns true - protected override bool ReleaseHandle() - { - NativeMethods.OrtReleaseMIGraphXProviderOptions(handle); - handle = IntPtr.Zero; - return true; - } - - #endregion - } - - /// /// This helper class contains methods to handle values of provider options /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index c85cd64efeec0..6e325f7fe9646 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -58,9 +58,6 @@ public class SessionOptions : SafeHandle private static string[] cudaDelayLoadedLibs = { }; private static string[] trtDelayLoadedLibs = { }; - // Delay-loaded MIGraphX DLLs. Currently, delayload is disabled. See cmake/CMakeLists.txt for more information. - private static string[] migxDelayLoadedLibs = { }; - #region Constructor and Factory methods /// @@ -208,28 +205,6 @@ public static SessionOptions MakeSessionOptionWithRocmProvider(OrtROCMProviderOp throw; } } - - /// - /// A helper method to construct a SessionOptions object for MIGraaphX execution provider. - /// Use only if MIGraphX is installed and you have the onnxruntime package specific to this Execution Provider. - /// - /// MIGraphX EP provider options - /// A SessionsOptions() object configured for execution on provider options - public static SessionOptions MakeSessionOptionWithMIGraphXProvider(OrtMIGraphXProviderOptions migxProviderOptions) - { - CheckMIGraphXExecutionProviderDLLs(); - SessionOptions options = new SessionOptions(); - try - { - options.AppendExecutionProvider_MIGraphX(migxProviderOptions); - return options; - } - catch (Exception) - { - options.Dispose(); - throw; - } - } #endregion #region ExecutionProviderAppends @@ -372,25 +347,12 @@ public void AppendExecutionProvider_ROCm(OrtROCMProviderOptions rocmProviderOpti public void AppendExecutionProvider_MIGraphX(int deviceId = 0) { #if __MOBILE__ - throw new NotSupportedException("The MIGraphX Execution Provider is not supported in this build"); + throw new NotSupportedException($"The MIGraphX Execution Provider is not supported in this build"); #else NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_MIGraphX(handle, deviceId)); #endif } - /// - /// Use only if you have the onnxruntime package specific to this Execution Provider. - /// - /// device identification - public void AppendExecutionProvider_MIGraphX(OrtMIGraphXProviderOptions migraphxProviderOptions) - { -#if __MOBILE__ - throw new NotSupportedException($"The AMD Nitris Execution Provider is not supported in this build"); -#else - NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_MIGraphX(handle, migraphxProviderOptions.Handle)); -#endif - } - /// /// Use only if you have the onnxruntime package specific to this Execution Provider. /// @@ -1163,27 +1125,6 @@ private static bool CheckRocmExecutionProviderDLLs() return true; } - private static bool CheckMIGraphXExecutionProviderDLLs() - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - foreach (var dll in migxDelayLoadedLibs) - { - IntPtr handle = LoadLibrary(dll); - if (handle != IntPtr.Zero) - continue; - var sysdir = new StringBuilder(String.Empty, 2048); - GetSystemDirectory(sysdir, (uint)sysdir.Capacity); - throw new OnnxRuntimeException( - ErrorCode.NoSuchFile, - $"kernel32.LoadLibrary():'{dll}' not found. MIGraphX are required for GPU execution. " + - $". Verify it is available in the system directory={sysdir}. Else copy it to the output folder." - ); - } - } - return true; - } - #endregion #region SafeHandle diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 71f1b44d37697..4b4423b1465d9 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6472,88 +6472,6 @@ struct OrtApi { _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); - - /// @} - /// \name OrtMIGraphXProviderOptions - /// @{ - - /** \brief Create an OrtMIGraphXProviderOptions - * - * \param[out] out Newly created ::OrtMIGraphXProviderOptions. Must be released with OrtApi::ReleaseMIGraphXProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.xx. - */ - ORT_API2_STATUS(CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out); - - /** \brief Set options in a MIGraphX Execution Provider. - * - * For example, key="device_id" and value="0" - * - * \param[in] migraphx_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.xx. - */ - ORT_API2_STATUS(UpdateMIGraphXProviderOptions, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** - * Get serialized MIGraphX provider options string. - * - * For example, "device_id=0;;......" - * - * \param migraphx_options - OrtMIGraphXProviderOptions instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.xx. - */ - ORT_API2_STATUS(GetMIGraphXProviderOptionsAsString, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an ::OrtMIGraphXProviderOptions - * - * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does - * - * \since Version 1.xx. - */ - void(ORT_API_CALL* ReleaseMIGraphXProviderOptions)(_Frees_ptr_opt_ OrtMIGraphXProviderOptions* input); - - /** - * Update MIGraphX EP provider option where its data type is pointer, for example 'user_compute_stream'. - * If the data type of the provider option can be represented by string please use UpdateMIGraphXProviderOptions. - * - * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. - * - * \param migraphx_options - OrtMIGraphXProviderOptions instance - * \param key - Name of the provider option - * \param value - A pointer to the instance that will be assigned to this provider option - * - * \since Version 1.xx. - */ - ORT_API2_STATUS(UpdateMIGraphXProviderOptionsWithValue, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _In_ void* value); - - /** - * Get MIGraphX EP provider option where its data type is pointer. - * If the data type of the provider option can be represented by string please use GetMIGraphXProviderOptionsAsString. - * - * \param migraphx_options - OrtMIGraphXProviderOptions instance - * \param key - Name of the provider option - * \param ptr - A pointer to the instance that is kept by the provider option - * - * \since Version 1.xx. - */ - ORT_API2_STATUS(GetMIGraphXProviderOptionsByName, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _Outptr_ void** ptr); }; /* diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc index 997f318fc640c..79dfb5512d3b5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -24,8 +24,8 @@ std::string_view RocmErrString(ERRTYPE x) { } #define CASE_ENUM_TO_STR(x) \ -case x: \ -return #x + case x: \ + return #x template <> std::string_view RocmErrString(hipError_t x) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 710a215c25f00..41d5dc6ed37b6 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -198,7 +198,6 @@ struct MIGraphX_Provider : Provider { return Status::OK(); } - void Initialize() override { #ifdef _WIN32 HMODULE module = nullptr; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 40bb721985acd..88d84e95b406c 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4095,13 +4095,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ReleaseSyncStream, &OrtApis::CopyTensors, - - &OrtApis::CreateMIGraphXProviderOptions, - &OrtApis::UpdateMIGraphXProviderOptions, - &OrtApis::GetMIGraphXProviderOptionsAsString, - &OrtApis::ReleaseMIGraphXProviderOptions, - &OrtApis::UpdateMIGraphXProviderOptionsWithValue, - &OrtApis::GetMIGraphXProviderOptionsByName}; +}; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. static_assert(sizeof(OrtApiBase) == sizeof(void*) * 2, "New methods can't be added to OrtApiBase as it is not versioned"); diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 2a2530ebcf054..3eee174ff81f4 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -740,15 +740,4 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); - -ORT_API_STATUS_IMPL(CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out); -ORT_API_STATUS_IMPL(UpdateMIGraphXProviderOptions, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - size_t num_keys); -ORT_API_STATUS_IMPL(GetMIGraphXProviderOptionsAsString, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); -ORT_API(void, ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions*); - -ORT_API_STATUS_IMPL(UpdateMIGraphXProviderOptionsWithValue, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _In_ void* value); -ORT_API_STATUS_IMPL(GetMIGraphXProviderOptionsByName, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _Outptr_ void** ptr); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index e1160291269dd..ddf9241795cb8 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -6,7 +6,6 @@ #include #include -#include #include "core/common/inlined_containers.h" #include "core/common/path_string.h" @@ -109,7 +108,6 @@ using EtwRegistrationManager_EtwInternalCallback = EtwRegistrationManager::EtwIn #include "core/providers/cann/cann_provider_factory.h" #include "core/providers/dnnl/dnnl_provider_factory.h" #include "core/providers/migraphx/migraphx_provider_factory.h" -#include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/openvino/openvino_provider_factory.h" #include "core/providers/tensorrt/tensorrt_provider_factory.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" @@ -2650,8 +2648,7 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions, #if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) || \ defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || \ defined(USE_CANN) || \ - defined(USE_DNNL) || \ - defined(USE_MIGRAPHX) + defined(USE_DNNL) static std::string BuildOptionsString(const onnxruntime::ProviderOptions::iterator& begin, const onnxruntime::ProviderOptions::iterator& end) { std::ostringstream options; @@ -3086,125 +3083,3 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ return nullptr; API_IMPL_END } - -ORT_API_STATUS_IMPL(OrtApis::CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out) { - API_IMPL_BEGIN -#ifdef USE_MIGRAPHX - auto migraphx_options = std::make_unique(); - memset(migraphx_options.get(), 0, sizeof(OrtMIGraphXProviderOptions)); - *out = migraphx_options.release(); - return nullptr; -#else - ORT_UNUSED_PARAMETER(out); - return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); -#endif - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptions, - _Inout_ OrtMIGraphXProviderOptions* migraphx_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - size_t num_keys) { - API_IMPL_BEGIN -#ifdef USE_MIGRAPHX - onnxruntime::ProviderOptions provider_options_map; - for (size_t i = 0; i != num_keys; ++i) { - if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || - provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty"); - } - - provider_options_map[provider_options_keys[i]] = provider_options_values[i]; - } - - onnxruntime::s_library_migraphx.Get().UpdateProviderOptions(reinterpret_cast(migraphx_options), provider_options_map); - return nullptr; -#else - ORT_UNUSED_PARAMETER(migraphx_options); - ORT_UNUSED_PARAMETER(provider_options_keys); - ORT_UNUSED_PARAMETER(provider_options_values); - ORT_UNUSED_PARAMETER(num_keys); - return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); -#endif - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsAsString, - _In_ const OrtMIGraphXProviderOptions* migraphx_options, - _Inout_ OrtAllocator* allocator, - _Outptr_ char** ptr) { - API_IMPL_BEGIN -#ifdef USE_MIGRAPHX - onnxruntime::ProviderOptions options = - onnxruntime::s_library_migraphx.Get().GetProviderOptions(reinterpret_cast(migraphx_options)); - std::string options_str = BuildOptionsString(options.begin(), options.end()); - *ptr = onnxruntime::StrDup(options_str, allocator); - return nullptr; -#else - ORT_UNUSED_PARAMETER(migraphx_options); - ORT_UNUSED_PARAMETER(allocator); - ORT_UNUSED_PARAMETER(ptr); - return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); -#endif - API_IMPL_END -} - -ORT_API(void, OrtApis::ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions* ptr) { -#ifdef USE_MIGRAPHX - std::unique_ptr p(ptr); - OrtAllocator* allocator; - GetAllocatorWithDefaultOptions(&allocator); - if (ptr->migraphx_cache_dir != nullptr) { - allocator->Free(allocator, const_cast(ptr->migraphx_cache_dir)); - } -#else - ORT_UNUSED_PARAMETER(ptr); -#endif -} - -ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptionsWithValue, - _Inout_ OrtMIGraphXProviderOptions* migraphx_options, - _In_ const char* key, - _In_ void* value) { - API_IMPL_BEGIN -#ifdef USE_MIGRAPHX - auto sv = std::string_view{key}; - OrtAllocator* allocator; - GetAllocatorWithDefaultOptions(&allocator); - if (sv == onnxruntime::migraphx_provider_option::kDeviceId) { - auto dv = std::string_view{static_cast(value)}; - if (std::from_chars(dv.data(), dv.data() + dv.length(), migraphx_options->device_id).ec == std::errc::invalid_argument) { - ORT_THROW("Cannot convert from string to integer - invalid argument"); - } - } else if (sv == onnxruntime::migraphx_provider_option::kModelCacheDir) { - auto sd = std::string_view{static_cast(value)}; - migraphx_options->migraphx_cache_dir = onnxruntime::StrDup(onnxruntime::ToPathString(sd), allocator); - } else { - ORT_THROW("Unsupported provider option name: '" + std::string{sv} + "'"); - } - return nullptr; -#else - ORT_UNUSED_PARAMETER(migraphx_options); - ORT_UNUSED_PARAMETER(key); - ORT_UNUSED_PARAMETER(value); - return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); -#endif - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsByName, - _In_ const OrtMIGraphXProviderOptions* migraphx_options, - _In_ const char* key, - _Outptr_ void** ptr) { - API_IMPL_BEGIN -#ifdef USE_MIGRAPHX - return nullptr; -#else - ORT_UNUSED_PARAMETER(migraphx_options); - ORT_UNUSED_PARAMETER(key); - ORT_UNUSED_PARAMETER(ptr); - return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); -#endif - API_IMPL_END -} diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index a99f873ef0eb9..48d52ae3cf428 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -627,56 +627,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, ORT_UNUSED_PARAMETER(num_keys); return CreateNotEnabledStatus("VitisAI"); } - -ORT_API_STATUS_IMPL(OrtApis::CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out) { - ORT_UNUSED_PARAMETER(out); - return CreateNotEnabledStatus("MIGraphX"); -} - -ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptions, - _Inout_ OrtMIGraphXProviderOptions* migraphx_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - size_t num_keys) { - ORT_UNUSED_PARAMETER(migraphx_options); - ORT_UNUSED_PARAMETER(provider_options_keys); - ORT_UNUSED_PARAMETER(provider_options_values); - ORT_UNUSED_PARAMETER(num_keys); - return CreateNotEnabledStatus("MIGraphX"); -} - -ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsAsString, - _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, - _Outptr_ char** ptr) { - ORT_UNUSED_PARAMETER(migraphx_options); - ORT_UNUSED_PARAMETER(allocator); - ORT_UNUSED_PARAMETER(ptr); - return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); -} - -ORT_API(void, OrtApis::ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions* ptr) { - ORT_UNUSED_PARAMETER(ptr); -} - -ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptionsWithValue, - _Inout_ OrtMIGraphXProviderOptions* migraphx_options, - _In_ const char* key, - _In_ void* value) { - ORT_UNUSED_PARAMETER(migraphx_options); - ORT_UNUSED_PARAMETER(key); - ORT_UNUSED_PARAMETER(value); - return CreateNotEnabledStatus("MIGraphX"); -} - -ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsByName, - _In_ const OrtMIGraphXProviderOptions* migraphx_options, - _In_ const char* key, - _Outptr_ void** ptr) { - ORT_UNUSED_PARAMETER(migraphx_options); - ORT_UNUSED_PARAMETER(key); - ORT_UNUSED_PARAMETER(ptr); - return CreateNotEnabledStatus("MIGraphX"); -} #endif ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* provider_options) { diff --git a/setup.py b/setup.py index 03e8ef671f149..6bfb53329f319 100644 --- a/setup.py +++ b/setup.py @@ -282,6 +282,8 @@ def run(self): self._rewrite_ld_preload_tensorrt(to_preload_tensorrt) self._rewrite_ld_preload_tensorrt(to_preload_nv_tensorrt_rtx) self._rewrite_ld_preload(to_preload_cann) + else: + pass _bdist_wheel.run(self) if is_manylinux and not disable_auditwheel_repair and not is_openvino and not is_qnn: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index c9bec9a3839a2..085f7024298c4 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1999,7 +1999,6 @@ def build_nuget_package( use_winml, use_qnn, use_dml, - use_migraphx, enable_training_apis, msbuild_extra_options, ): @@ -2037,9 +2036,6 @@ def build_nuget_package( elif use_tensorrt: execution_provider = "/p:ExecutionProvider=tensorrt" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" - elif use_migraphx: - execution_provider = "/p:ExecutionProvider=migraphx" - package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.MIGraphX" elif use_dnnl: execution_provider = "/p:ExecutionProvider=dnnl" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DNNL" @@ -2631,7 +2627,6 @@ def main(): getattr(args, "use_winml", False), args.use_qnn, getattr(args, "use_dml", False), - args.use_migraphx, args.enable_training_apis, normalize_arg_list(args.msbuild_extra_options), ) diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index ead240a7cef1b..211cb7a2a8a75 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -22,8 +22,6 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" - elif ep == "migraphx": - pkg_name += "-migraphx" elif os == "linux": pkg_name += "-linux-" pkg_name += cpu_arch @@ -33,8 +31,6 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" - elif ep == "migraphx": - pkg_name += "-migraphx" elif os == "osx": pkg_name = "onnxruntime-osx-" + cpu_arch return pkg_name @@ -48,11 +44,7 @@ def get_package_name(os, cpu_arch, ep, is_training_package): def is_this_file_needed(ep, filename, package_name): if package_name == "Microsoft.ML.OnnxRuntime.Gpu": return False - return ( - (ep != "cuda" or "cuda" in filename) - and (ep != "tensorrt" or "cuda" not in filename) - and (ep != "migraphx" or "migraphx" not in filename) - ) + return (ep != "cuda" or "cuda" in filename) and (ep != "tensorrt" or "cuda" not in filename) # nuget_artifacts_dir: the directory with uncompressed C API tarball/zip files @@ -146,7 +138,7 @@ def parse_arguments(): required=False, default="None", type=str, - choices=["cuda", "dnnl", "openvino", "migraphx", "tensorrt", "snpe", "qnn", "None"], + choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"], help="The selected execution provider for this build.", ) parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.") @@ -190,8 +182,6 @@ def generate_description(line_list, package_name): description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA." elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name: description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA." - elif "Microsoft.ML.OnnxRuntime.MIGraphX" in package_name: - description = "This package contains native shared library artifacts for ONNX Runtime with MIGraphX." elif "Intel.ML.OnnxRuntime" in package_name: description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO." elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package @@ -369,7 +359,6 @@ def generate_files(line_list, args): is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning" is_snpe_package = args.package_name == "Microsoft.ML.OnnxRuntime.Snpe" is_qnn_package = args.package_name == "Microsoft.ML.OnnxRuntime.QNN" - is_migraphx_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX" is_training_package = args.package_name in [ "Microsoft.ML.OnnxRuntime.Training", "Microsoft.ML.OnnxRuntime.Training.Gpu", @@ -395,24 +384,6 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", "qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll", - "migraphx_ep_shared_lib": "onnxruntime_providers_migraphx.dll", - "amd_comgr0602": "amd_comgr0602.dll", - "amd_comgr0604": "amd_comgr0604.dll", - "amd_comgr0700": "amd_comgr0700.dll", - "hiprtc0602": "hiprtc0602.dll", - "hiprtc0604": "hiprtc0604.dll", - "hiprtc0700": "hiprtc0700.dll", - "hiprtc-builtins0602": "hiprtc-builtins0602.dll", - "hiprtc-builtins0604": "hiprtc-builtins0604.dll", - "hiprtc-builtins0700": "hiprtc-builtins0700.dll", - "migraphx-hiprtc-driver": "migraphx-hiprtc-driver.exe", - "migraphx": "migraphx.dll", - "migraphx_c": "migraphx_c.dll", - "migraphx_cpu": "migraphx_cpu.dll", - "migraphx_device": "migraphx_device.dll", - "migraphx_gpu": "migraphx_gpu.dll", - "migraphx_onnx": "migraphx_onnx.dll", - "migraphx_tf": "migraphx_tf", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -431,7 +402,6 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "libonnxruntime_providers_openvino.so", "cuda_ep_shared_lib": "libonnxruntime_providers_cuda.so", "rocm_ep_shared_lib": "libonnxruntime_providers_rocm.so", - "migraphx_ep_shared_lib": "libonnxruntime_providers_migraphx.so", "onnxruntime_perf_test": "onnxruntime_perf_test", "onnx_test_runner": "onnx_test_runner", } @@ -451,7 +421,7 @@ def generate_files(line_list, args): include_dir = f"{build_dir}\\native\\include" # Sub.Gpu packages do not include the onnxruntime headers - if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu" and args.package_name != "Microsoft.ML.OnnxRuntime.MIGraphX": + if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu": files_list.append( "' ) - if args.execution_provider == "migraphx": - files_list.append( - "' - ) - files_list.append( - "' - ) - - if is_windows_build: - native_build_path = Path(args.native_build_path) - - def _files_list_append(key: str): - path = native_build_path / nuget_dependencies[key] - if path.exists(): - files_list.append( - "' - ) - - _files_list_append("amd_comgr0602") - _files_list_append("amd_comgr0604") - _files_list_append("amd_comgr0700") - _files_list_append("hiprtc0602") - _files_list_append("hiprtc0604") - _files_list_append("hiprtc0700") - _files_list_append("hiprtc-builtins0602") - _files_list_append("hiprtc-builtins0604") - _files_list_append("hiprtc-builtins0700") - _files_list_append("migraphx-hiprtc-driver") - _files_list_append("migraphx") - _files_list_append("migraphx_c") - _files_list_append("migraphx_cpu") - _files_list_append("migraphx_device") - _files_list_append("migraphx_gpu") - _files_list_append("migraphx_onnx") - _files_list_append("migraphx_tf") - if is_dml_package: files_list.append( " Date: Wed, 6 Aug 2025 09:12:03 +0200 Subject: [PATCH 39/46] Use auto Ep version from 'main' branch --- .../migraphx/migraphx_provider_factory.cc | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 41d5dc6ed37b6..cde4d2b65c797 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -245,7 +245,7 @@ struct MigraphXEpFactory : OrtEpFactory { GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; - GetVendorId = GetVendorIdImpl; + CreateAllocator = CreateAllocatorImpl; ReleaseAllocator = ReleaseAllocatorImpl; CreateDataTransfer = CreateDataTransferImpl; @@ -266,21 +266,14 @@ struct MigraphXEpFactory : OrtEpFactory { return factory->vendor.c_str(); } - static const char* GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->version.c_str(); - } - static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->vendor_id; } - static OrtStatus* CreateDataTransferImpl(OrtEpFactory* this_ptr, - OrtDataTransferImpl** data_transfer) noexcept { - ORT_UNUSED_PARAMETER(this_ptr); - *data_transfer = nullptr; // return nullptr to indicate that this EP does not support data transfer. - return nullptr; + static const char* GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->version.c_str(); } // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. @@ -343,6 +336,12 @@ struct MigraphXEpFactory : OrtEpFactory { // should never be called as we don't implement CreateAllocator } + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; // not implemented + return nullptr; + } + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { return false; } From 781fa34279586befa82b9d2cd16d24d959b7a1f9 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Wed, 6 Aug 2025 09:12:26 +0200 Subject: [PATCH 40/46] Create C# NuGet package for MIGraphX --- tools/ci_build/build.py | 5 ++ .../nuget/generate_nuspec_for_native_nuget.py | 88 ++++++++++++++++++- 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index d1680a4e9a31c..493f754786f17 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1999,6 +1999,7 @@ def build_nuget_package( use_winml, use_qnn, use_dml, + use_migraphx, enable_training_apis, msbuild_extra_options, ): @@ -2036,6 +2037,9 @@ def build_nuget_package( elif use_tensorrt: execution_provider = "/p:ExecutionProvider=tensorrt" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" + elif use_migraphx: + execution_provider = "/p:ExecutionProvider=migraphx" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.MIGraphX" elif use_dnnl: execution_provider = "/p:ExecutionProvider=dnnl" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DNNL" @@ -2627,6 +2631,7 @@ def main(): getattr(args, "use_winml", False), args.use_qnn, getattr(args, "use_dml", False), + args.use_migraphx, args.enable_training_apis, normalize_arg_list(args.msbuild_extra_options), ) diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 211cb7a2a8a75..ead240a7cef1b 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -22,6 +22,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "linux": pkg_name += "-linux-" pkg_name += cpu_arch @@ -31,6 +33,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "osx": pkg_name = "onnxruntime-osx-" + cpu_arch return pkg_name @@ -44,7 +48,11 @@ def get_package_name(os, cpu_arch, ep, is_training_package): def is_this_file_needed(ep, filename, package_name): if package_name == "Microsoft.ML.OnnxRuntime.Gpu": return False - return (ep != "cuda" or "cuda" in filename) and (ep != "tensorrt" or "cuda" not in filename) + return ( + (ep != "cuda" or "cuda" in filename) + and (ep != "tensorrt" or "cuda" not in filename) + and (ep != "migraphx" or "migraphx" not in filename) + ) # nuget_artifacts_dir: the directory with uncompressed C API tarball/zip files @@ -138,7 +146,7 @@ def parse_arguments(): required=False, default="None", type=str, - choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"], + choices=["cuda", "dnnl", "openvino", "migraphx", "tensorrt", "snpe", "qnn", "None"], help="The selected execution provider for this build.", ) parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.") @@ -182,6 +190,8 @@ def generate_description(line_list, package_name): description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA." elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name: description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA." + elif "Microsoft.ML.OnnxRuntime.MIGraphX" in package_name: + description = "This package contains native shared library artifacts for ONNX Runtime with MIGraphX." elif "Intel.ML.OnnxRuntime" in package_name: description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO." elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package @@ -359,6 +369,7 @@ def generate_files(line_list, args): is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning" is_snpe_package = args.package_name == "Microsoft.ML.OnnxRuntime.Snpe" is_qnn_package = args.package_name == "Microsoft.ML.OnnxRuntime.QNN" + is_migraphx_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX" is_training_package = args.package_name in [ "Microsoft.ML.OnnxRuntime.Training", "Microsoft.ML.OnnxRuntime.Training.Gpu", @@ -384,6 +395,24 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", "qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll", + "migraphx_ep_shared_lib": "onnxruntime_providers_migraphx.dll", + "amd_comgr0602": "amd_comgr0602.dll", + "amd_comgr0604": "amd_comgr0604.dll", + "amd_comgr0700": "amd_comgr0700.dll", + "hiprtc0602": "hiprtc0602.dll", + "hiprtc0604": "hiprtc0604.dll", + "hiprtc0700": "hiprtc0700.dll", + "hiprtc-builtins0602": "hiprtc-builtins0602.dll", + "hiprtc-builtins0604": "hiprtc-builtins0604.dll", + "hiprtc-builtins0700": "hiprtc-builtins0700.dll", + "migraphx-hiprtc-driver": "migraphx-hiprtc-driver.exe", + "migraphx": "migraphx.dll", + "migraphx_c": "migraphx_c.dll", + "migraphx_cpu": "migraphx_cpu.dll", + "migraphx_device": "migraphx_device.dll", + "migraphx_gpu": "migraphx_gpu.dll", + "migraphx_onnx": "migraphx_onnx.dll", + "migraphx_tf": "migraphx_tf", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -402,6 +431,7 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "libonnxruntime_providers_openvino.so", "cuda_ep_shared_lib": "libonnxruntime_providers_cuda.so", "rocm_ep_shared_lib": "libonnxruntime_providers_rocm.so", + "migraphx_ep_shared_lib": "libonnxruntime_providers_migraphx.so", "onnxruntime_perf_test": "onnxruntime_perf_test", "onnx_test_runner": "onnx_test_runner", } @@ -421,7 +451,7 @@ def generate_files(line_list, args): include_dir = f"{build_dir}\\native\\include" # Sub.Gpu packages do not include the onnxruntime headers - if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu": + if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu" and args.package_name != "Microsoft.ML.OnnxRuntime.MIGraphX": files_list.append( "' ) + if args.execution_provider == "migraphx": + files_list.append( + "' + ) + files_list.append( + "' + ) + + if is_windows_build: + native_build_path = Path(args.native_build_path) + + def _files_list_append(key: str): + path = native_build_path / nuget_dependencies[key] + if path.exists(): + files_list.append( + "' + ) + + _files_list_append("amd_comgr0602") + _files_list_append("amd_comgr0604") + _files_list_append("amd_comgr0700") + _files_list_append("hiprtc0602") + _files_list_append("hiprtc0604") + _files_list_append("hiprtc0700") + _files_list_append("hiprtc-builtins0602") + _files_list_append("hiprtc-builtins0604") + _files_list_append("hiprtc-builtins0700") + _files_list_append("migraphx-hiprtc-driver") + _files_list_append("migraphx") + _files_list_append("migraphx_c") + _files_list_append("migraphx_cpu") + _files_list_append("migraphx_device") + _files_list_append("migraphx_gpu") + _files_list_append("migraphx_onnx") + _files_list_append("migraphx_tf") + if is_dml_package: files_list.append( " Date: Thu, 7 Aug 2025 12:53:42 +0200 Subject: [PATCH 41/46] fix pybind compilation after merging 'main' --- .../python/onnxruntime_pybind_mlvalue.cc | 14 +++++++------- .../python/onnxruntime_pybind_mlvalue.h | 8 ++++---- .../python/onnxruntime_pybind_ortvalue.cc | 19 ++++++++++--------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index e363929708d3e..1934e0eda7956 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -196,9 +196,9 @@ void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_CUDA().cudaMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetCudaToHostMemCpyFunction() { +const std::unordered_map* GetCudaToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0}, CudaToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device.Id()}, CudaToCpuMemCpy}, }; return ↦ @@ -257,7 +257,7 @@ void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { const std::unordered_map* GetMIGraphXToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, MIGraphXToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, MIGraphXToCpuMemCpy}, }; return ↦ @@ -379,9 +379,9 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { D3D12_RESOURCE_STATE_UNORDERED_ACCESS); } -const std::unordered_map* GetDmlToHostMemCpyFunction() { +const std::unordered_map* GetDmlToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, 0}, DmlToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device.Id()}, DmlToCpuMemCpy}, }; return ↦ @@ -449,9 +449,9 @@ void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_ROCM().rocmMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetRocmToHostMemCpyFunction() { +const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, RocmToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, RocmToCpuMemCpy}, }; return ↦ diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h index 7b65c0aae45c1..eba783d826212 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h @@ -74,7 +74,7 @@ void CpuToCudaMemCpy(void* dst, const void* src, size_t num_bytes); void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetCudaToHostMemCpyFunction(); +const std::unordered_map* GetCudaToHostMemCpyFunction(const OrtDevice&); bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id); @@ -92,7 +92,7 @@ void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes); void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetDmlToHostMemCpyFunction(); +const std::unordered_map* GetDmlToHostMemCpyFunction(const OrtDevice&); #endif @@ -102,7 +102,7 @@ void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes); void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetMIGraphXToHostMemCpyFunction(); +const std::unordered_map* GetMIGraphXToHostMemCpyFunction(const OrtDevice&); AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id); @@ -132,7 +132,7 @@ void CpuToRocmMemCpy(void* dst, const void* src, size_t num_bytes); void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetRocmToHostMemCpyFunction(); +const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice&); #endif diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index c33bee45d0237..9a1c415b9a365 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -425,21 +425,22 @@ void addOrtValueMethods(pybind11::module& m) { switch (device.Vendor()) { #ifdef USE_CUDA case OrtDevice::VendorIds::NVIDIA: - return GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction()); + return GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction(device)); #endif -#ifdef USE_MIGRAPHX - case OrtDevice::VendorIds::AMD: - return GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction()); +#ifdef USE_CANN + case OrtDevice::VendorIds::HUAWEI: + return GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); #endif + #ifdef USE_DML case OrtDevice::VendorIds::MICROSOFT: - return GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction()); + return GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction(device)); #endif -#ifdef USE_CANN - case OrtDevice::VendorIds::HUAWEI: - return GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); +#ifdef USE_MIGRAPHX + case OrtDevice::VendorIds::AMD: + return GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction(device)); #endif - default: + default: return GetPyObjFromTensor(*ml_value, nullptr, nullptr); } }) #if defined(ENABLE_DLPACK) From 7e38f717ad9016280b19acc8291cbd3a30c297ea Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 7 Aug 2025 16:19:06 +0200 Subject: [PATCH 42/46] do not alter OrtMIGraphXProvierOptions --- .../core/session/onnxruntime_c_api.h | 7 +- .../migraphx/migraphx_execution_provider.cc | 12 +- .../migraphx/migraphx_execution_provider.h | 16 ++- .../migraphx_execution_provider_info.cc | 18 +-- .../migraphx_execution_provider_info.h | 4 - .../migraphx/migraphx_provider_factory.cc | 55 ++------- .../migraphx/migraphx_provider_factory.h | 4 +- .../core/session/provider_bridge_ort.cc | 9 +- .../python/onnxruntime_pybind_state.cc | 112 +----------------- onnxruntime/test/util/default_providers.cc | 23 +--- 10 files changed, 46 insertions(+), 214 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4b4423b1465d9..6eb15280a4aa4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -775,11 +775,8 @@ typedef struct OrtMIGraphXProviderOptions { * \note If a ::OrtArenaCfg has been applied, it will override this field */ int migraphx_arena_extend_strategy; - int migraphx_bf16_enable; // MIGraphX BF16 precision. Default 0 = false, nonzero = true - const ORTCHAR_T* migraphx_cache_dir; // MIGraphX model cache directory - void* migraphx_external_alloc; // Pointer to an external Alloc() function (default is none) - void* migraphx_external_free; // Pointer to an external Free() function (default is none) - void* migraphx_external_empty_cache; // Pointer to an external EmptyCache() function (default is none) + + // This is the legacy struct and don't add new fields here. } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 5e2d611443918..a59347841be95 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -203,9 +203,9 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv } if (int8_enable_ ^ fp8_enable_) { - int8_calibration_cache_name_ = + int8_calibration_table_name_ = int8_calibration_cache_name_env.empty() ? info.int8_calibration_table_name : int8_calibration_cache_name_env; - int8_use_native_migraphx_calibration_table_ = + int8_use_native_calibration_table_ = int8_use_native_migraphx_calibration_table_env.empty() ? info.int8_use_native_calibration_table : std::stoi(int8_use_native_migraphx_calibration_table_env) != 0; } @@ -216,8 +216,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv // Load INT8 calibration table if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { std::unordered_map dynamic_range_map; - auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { + auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_table_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_calibration_table_, dynamic_range_map)) { throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path.string()); } } @@ -234,9 +234,9 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv << "\n " << migraphx_provider_option::kArenaExtendStrategy << ": " << GetArenaExtendStrategyName(arena_extend_strategy_) << "\n dump_model_ops: " << dump_model_ops_ << "\n " << migraphx_provider_option::kExhaustiveTune << ": " << exhaustive_tune_ - << "\n " << migraphx_provider_option::kInt8CalibTable << ": " << int8_calibration_cache_name_ + << "\n " << migraphx_provider_option::kInt8CalibTable << ": " << int8_calibration_table_name_ << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ - << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_migraphx_calibration_table_ + << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_calibration_table_ << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 9d9e1d0e1dd1e..99f790b9f9f7a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -31,10 +31,6 @@ constexpr auto kDumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS"sv; constexpr auto kINT8CalibrationTableName = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"sv; constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH"sv; constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"sv; -constexpr auto kSaveCompiledModel = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"sv; -constexpr auto kSavedModelPath = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"sv; -constexpr auto kLoadCompiledModel = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"sv; -constexpr auto kLoadModelPath = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"sv; constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; } // namespace migraphx_env_vars @@ -99,13 +95,15 @@ class MIGraphXExecutionProvider : public IExecutionProvider { {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable_)}, {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable_)}, {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable_)}, - {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_path_)}, + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name_)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table_)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune_)}, {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit_)}, {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy_)}, - {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune_)}, {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc_)}, {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free_)}, - {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache_)}}; + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache_)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_path_)}}; } private: @@ -114,9 +112,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool bf16_enable_ = false; bool fp8_enable_ = false; bool int8_enable_ = false; - std::string int8_calibration_cache_name_; + std::string int8_calibration_table_name_; bool int8_calibration_cache_available_ = false; - bool int8_use_native_migraphx_calibration_table_ = false; + bool int8_use_native_calibration_table_ = false; std::filesystem::path calibration_cache_path_{}; std::unordered_map dynamic_range_map_; std::filesystem::path model_cache_path_{}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 77a0d8014b678..33ef366eb18e5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -57,11 +57,18 @@ MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptio external_empty_cache = reinterpret_cast(address); return Status::OK(); }) + .AddValueParser( + migraphx_provider_option::kModelCacheDir, + [this](const std::string& value_str) -> Status { + model_cache_dir = ToPathString(value_str); + return Status::OK(); + }) .AddAssignmentToReference(migraphx_provider_option::kFp16Enable, fp16_enable) .AddAssignmentToReference(migraphx_provider_option::kBf16Enable, bf16_enable) .AddAssignmentToReference(migraphx_provider_option::kFp8Enable, fp8_enable) .AddAssignmentToReference(migraphx_provider_option::kInt8Enable, int8_enable) - .AddAssignmentToReference(migraphx_provider_option::kModelCacheDir, model_cache_dir) + .AddAssignmentToReference(migraphx_provider_option::kInt8UseNativeCalibTable, int8_use_native_calibration_table) + .AddAssignmentToReference(migraphx_provider_option::kInt8CalibTable, int8_calibration_table_name) .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, exhaustive_tune) .AddAssignmentToReference(migraphx_provider_option::kMemLimit, mem_limit) .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) @@ -71,16 +78,11 @@ MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptio MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const OrtMIGraphXProviderOptions& options) noexcept : device_id{static_cast(options.device_id)}, fp16_enable{options.migraphx_fp16_enable != 0}, - bf16_enable{options.migraphx_bf16_enable != 0}, fp8_enable{options.migraphx_fp8_enable != 0}, int8_enable{options.migraphx_int8_enable != 0}, - model_cache_dir{options.migraphx_cache_dir}, exhaustive_tune{options.migraphx_exhaustive_tune != 0}, mem_limit{options.migraphx_mem_limit}, - arena_extend_strategy{options.migraphx_arena_extend_strategy}, - external_alloc{options.migraphx_external_alloc}, - external_free{options.migraphx_external_free}, - external_empty_cache{options.migraphx_external_empty_cache} { + arena_extend_strategy{options.migraphx_arena_extend_strategy} { } ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { @@ -90,6 +92,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable)}, {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable)}, {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable)}, + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table)}, {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit)}, {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy)}, {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune)}, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index f08201a3aff06..414254aaa2629 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -27,10 +27,6 @@ constexpr auto kFp8Enable = "migraphx_fp8_enable"sv; constexpr auto kInt8Enable = "migraphx_int8_enable"sv; constexpr auto kInt8CalibTable = "migraphx_int8_calibration_table_name"sv; constexpr auto kInt8UseNativeCalibTable = "migraphx_int8_use_native_calibration_table"sv; -constexpr auto kSaveCompiledModel = "migraphx_save_compiled_model"sv; -constexpr auto kSaveModelPath = "migraphx_save_model_name"sv; -constexpr auto kLoadCompiledModel = "migraphx_load_compiled_model"sv; -constexpr auto kLoadModelPath = "migraphx_load_model_name"sv; constexpr auto kExhaustiveTune = "migraphx_exhaustive_tune"sv; constexpr auto kMemLimit = "migraphx_mem_limit"sv; constexpr auto kArenaExtendStrategy = "migraphx_arena_extend_strategy"sv; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index cde4d2b65c797..ef8b81631c3fa 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -97,9 +97,11 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { } } g_info; -struct MIGraphX_Provider : Provider { +struct MIGraphX_Provider final : Provider { void* GetInfo() override { return &g_info; } + virtual ~MIGraphX_Provider() = default; + std::shared_ptr CreateExecutionProviderFactory(int device_id) override { MIGraphXExecutionProviderInfo info; info.device_id = static_cast(device_id); @@ -107,28 +109,13 @@ struct MIGraphX_Provider : Provider { return std::make_shared(info); } + // Method uses ProviderOptions, and not OrtMIGraphXProviderOptions (obsolete) std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { - auto& options = *static_cast(provider_options); - MIGraphXExecutionProviderInfo info; - info.device_id = static_cast(options.device_id); - info.target_device = "gpu"; - info.fp16_enable = options.migraphx_fp16_enable; - info.bf16_enable = options.migraphx_bf16_enable; - info.fp8_enable = options.migraphx_fp8_enable; - info.exhaustive_tune = options.migraphx_exhaustive_tune; - info.int8_enable = options.migraphx_int8_enable; - info.int8_calibration_table_name = ""; - if (options.migraphx_int8_calibration_table_name != nullptr) { - info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name; - } - info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0; - info.model_cache_dir = ""; - if (options.migraphx_cache_dir != nullptr) { - info.model_cache_dir = options.migraphx_cache_dir; + if (provider_options != nullptr) { + return std::make_shared(MIGraphXExecutionProviderInfo{ + MIGraphXExecutionProviderInfo{*static_cast(provider_options)}}); } - info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy); - info.mem_limit = options.migraphx_mem_limit; - return std::make_shared(info); + return nullptr; } void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { @@ -136,7 +123,6 @@ struct MIGraphX_Provider : Provider { const auto migx_options = static_cast(provider_options); migx_options->device_id = internal_options.device_id; migx_options->migraphx_fp16_enable = internal_options.fp16_enable; - migx_options->migraphx_bf16_enable = internal_options.bf16_enable; migx_options->migraphx_fp8_enable = internal_options.fp8_enable; migx_options->migraphx_int8_enable = internal_options.int8_enable; migx_options->migraphx_exhaustive_tune = internal_options.exhaustive_tune; @@ -157,30 +143,13 @@ struct MIGraphX_Provider : Provider { migx_options->migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; - if (internal_options.model_cache_dir.empty()) { - migx_options->migraphx_cache_dir = nullptr; - } else { - const auto cache_dir_str{internal_options.model_cache_dir.native()}; - auto cache_dir = new ORTCHAR_T[cache_dir_str.size() + 1]; -#ifdef _MSC_VER - wcsncpy_s(cache_dir, cache_dir_str.size() + 1, cache_dir_str.data(), cache_dir_str.size()); -#else - strncpy(cache_dir, cache_dir_str.data(), cache_dir_str.size()); -#endif - cache_dir[cache_dir_str.size()] = '\0'; - migx_options->migraphx_cache_dir = cache_dir; - } - migx_options->migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); migx_options->migraphx_mem_limit = internal_options.mem_limit; - - migx_options->migraphx_external_alloc = internal_options.external_alloc; - migx_options->migraphx_external_free = internal_options.external_free; - migx_options->migraphx_external_empty_cache = internal_options.external_empty_cache; } ProviderOptions GetProviderOptions(const void* provider_options) override { - return MIGraphXExecutionProviderInfo{*static_cast(provider_options)}.ToProviderOptions(); + return provider_options != nullptr ? MIGraphXExecutionProviderInfo{ + *static_cast(provider_options)}.ToProviderOptions() : ProviderOptions{}; } Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, @@ -191,9 +160,7 @@ struct MIGraphX_Provider : Provider { const OrtLogger& logger, std::unique_ptr& ep) override { ORT_UNUSED_PARAMETER(num_devices); - OrtMIGraphXProviderOptions migraphx_options; - UpdateProviderOptions(&migraphx_options, provider_options); - const auto ep_factory = CreateExecutionProviderFactory(&migraphx_options); + const auto ep_factory = CreateExecutionProviderFactory(&provider_options); ep = ep_factory->CreateProvider(session_options, logger); return Status::OK(); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index a19fa7a87fec1..c23c9947c8d9b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -12,8 +12,8 @@ namespace onnxruntime { class IAllocator; struct ProviderInfo_MIGraphX { - virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; - virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0; virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0; virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index ddf9241795cb8..ee59ff2ab4932 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2109,13 +2109,12 @@ std::shared_ptr NvProviderFactoryCreator::Create( } std::shared_ptr MIGraphXProviderFactoryCreator::Create(const ProviderOptions& provider_options) { - OrtMIGraphXProviderOptions migraphx_options; - s_library_migraphx.Get().UpdateProviderOptions(&migraphx_options, provider_options); - return s_library_migraphx.Get().CreateExecutionProviderFactory(&migraphx_options); + return s_library_migraphx.Get().CreateExecutionProviderFactory(&provider_options); } -std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { - return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options); +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* options) { + const auto provider_options{s_library_migraphx.Get().GetProviderOptions(options)}; + return s_library_migraphx.Get().CreateExecutionProviderFactory(&provider_options); } // Adapter to convert the legacy OrtOpenVINOProviderOptions to ProviderOptions diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 72bf33079f216..24554560b4dde 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -979,120 +979,10 @@ static std::shared_ptr CreateExecutionProviderFactory #endif } else if (type == kMIGraphXExecutionProvider) { #if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE) - std::string calibration_table; - PathString model_cache_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { - OrtMIGraphXProviderOptions params{ - 0, - 0, - 0, - 0, - 0, - nullptr, - 0, - nullptr, - 0, - nullptr, - 1, - SIZE_MAX, - 0, - 0, - nullptr, - nullptr, - nullptr, - nullptr}; - for (auto option : it->second) { - if (option.first == "device_id") { - if (!option.second.empty()) { - params.device_id = std::stoi(option.second); - } else { - ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n"); - } - } else if (option.first == migraphx_provider_option::kFp16Enable) { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == migraphx_provider_option::kBf16Enable) { - if (option.second == "True" || option.second == "true") { - params.migraphx_bf16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_bf16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_bf16_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == migraphx_provider_option::kFp8Enable) { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp8_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp8_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == migraphx_provider_option::kInt8Enable) { - if (option.second == "True" || option.second == "true") { - params.migraphx_int8_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_int8_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == migraphx_provider_option::kInt8CalibTable) { - if (!option.second.empty()) { - calibration_table = option.second; - params.migraphx_int8_calibration_table_name = calibration_table.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a " - "file name i.e. 'cal_table'.\n"); - } - } else if (option.first == migraphx_provider_option::kInt8UseNativeCalibTable) { - if (option.second == "True" || option.second == "true") { - params.migraphx_use_native_calibration_table = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_use_native_calibration_table = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == migraphx_provider_option::kModelCacheDir) { - if (!option.second.empty()) { - model_cache_path = ToPathString(option.second); - params.migraphx_cache_dir = model_cache_path.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a " - "file name i.e. 'compiled_model.mxr'.\n"); - } - } else if (option.first == migraphx_provider_option::kExhaustiveTune) { - if (option.second == "True" || option.second == "true") { - params.migraphx_exhaustive_tune = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_exhaustive_tune = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else { - ORT_THROW("Invalid MIGraphX EP option: ", option.first); - } - } if (std::shared_ptr migraphx_provider_factory = - onnxruntime::MIGraphXProviderFactoryCreator::Create(¶ms)) { + onnxruntime::MIGraphXProviderFactoryCreator::Create(it->second)) { return migraphx_provider_factory; } } else { diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 8df845d7ea5d6..bae7a14908916 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -80,26 +80,7 @@ std::unique_ptr TensorrtExecutionProviderWithOptions(const O std::unique_ptr DefaultMIGraphXExecutionProvider() { #ifdef USE_MIGRAPHX - constexpr OrtMIGraphXProviderOptions params{ - 0, - 0, - 0, - 0, - 0, - nullptr, - 0, - nullptr, - 0, - nullptr, - 1, - SIZE_MAX, - 0, - 0, - nullptr, - nullptr, - nullptr, - nullptr}; - return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider(); + return MIGraphXProviderFactoryCreator::Create(ProviderOptions{})->CreateProvider(); #else return nullptr; #endif @@ -107,7 +88,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { std::unique_ptr MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params) { #ifdef USE_MIGRAPHX - if (auto factory = MIGraphXProviderFactoryCreator::Create(params)) + if (const auto factory = MIGraphXProviderFactoryCreator::Create(params); factory != nullptr) return factory->CreateProvider(); #else ORT_UNUSED_PARAMETER(params); From 8be40b0b5207ba8dda9106bd33d9979872d436fb Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 7 Aug 2025 17:28:05 +0200 Subject: [PATCH 43/46] remove nested MIGraphXExecutionProviderInfo --- .../core/providers/migraphx/migraphx_provider_factory.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index ef8b81631c3fa..914ea3f375fe3 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -112,8 +112,9 @@ struct MIGraphX_Provider final : Provider { // Method uses ProviderOptions, and not OrtMIGraphXProviderOptions (obsolete) std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { if (provider_options != nullptr) { - return std::make_shared(MIGraphXExecutionProviderInfo{ - MIGraphXExecutionProviderInfo{*static_cast(provider_options)}}); + return std::make_shared( + MIGraphXExecutionProviderInfo{*static_cast(provider_options)} + ); } return nullptr; } From b9d7e267e94d4f64a968f504d9c7a0aad7170600 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Thu, 7 Aug 2025 17:45:27 +0200 Subject: [PATCH 44/46] lintrunner --- .../core/providers/migraphx/migraphx_provider_factory.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 914ea3f375fe3..37c76c9f46846 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -113,8 +113,7 @@ struct MIGraphX_Provider final : Provider { std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { if (provider_options != nullptr) { return std::make_shared( - MIGraphXExecutionProviderInfo{*static_cast(provider_options)} - ); + MIGraphXExecutionProviderInfo{*static_cast(provider_options)}); } return nullptr; } @@ -149,8 +148,10 @@ struct MIGraphX_Provider final : Provider { } ProviderOptions GetProviderOptions(const void* provider_options) override { - return provider_options != nullptr ? MIGraphXExecutionProviderInfo{ - *static_cast(provider_options)}.ToProviderOptions() : ProviderOptions{}; + return provider_options != nullptr ? MIGraphXExecutionProviderInfo{ + *static_cast(provider_options)} + .ToProviderOptions() + : ProviderOptions{}; } Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, From 481c854a7dcfa6fbcee81bb18b747bf07989c0ae Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Fri, 8 Aug 2025 01:34:28 +0200 Subject: [PATCH 45/46] disable C4065 warning for the switch statement --- onnxruntime/python/onnxruntime_pybind_ortvalue.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 9a1c415b9a365..12205c08723e6 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -422,6 +422,13 @@ void addOrtValueMethods(pybind11::module& m) { .def("numpy", [](const OrtValue* ml_value) -> py::object { ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects"); const auto& device = ml_value->Get().Location().device; +#ifdef _MSC_VER +// The switch statement may only contain the 'default' label. In such a case, the MSVC compiler +// will warn about it, and since the warnings are treated as errors, the compilation will break. +// Below pragmas turn off warning generation for this switch only. +#pragma warning(push) +#pragma warning(disable : 4065) +#endif switch (device.Vendor()) { #ifdef USE_CUDA case OrtDevice::VendorIds::NVIDIA: @@ -440,9 +447,13 @@ void addOrtValueMethods(pybind11::module& m) { case OrtDevice::VendorIds::AMD: return GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction(device)); #endif - default: + default: return GetPyObjFromTensor(*ml_value, nullptr, nullptr); - } }) + } +#ifdef _MSC_VER +#pragma warning(pop) +#endif + }) #if defined(ENABLE_DLPACK) .def("to_dlpack", [](OrtValue* ort_value) -> py::object { return py::reinterpret_steal(ToDlpack(*ort_value)); }, "Returns a DLPack representing the tensor. This method does not copy the pointer shape, " From 1ac92bc04d8401d6d9d354cc14e33e374f7b85f1 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Fri, 8 Aug 2025 01:53:57 +0200 Subject: [PATCH 46/46] Update onnxruntime/python/onnxruntime_pybind_ortvalue.cc Co-authored-by: Scott McKay --- onnxruntime/python/onnxruntime_pybind_ortvalue.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 12205c08723e6..1fe7ab0884f9c 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -421,7 +421,7 @@ void addOrtValueMethods(pybind11::module& m) { // Converts Tensor into a numpy array .def("numpy", [](const OrtValue* ml_value) -> py::object { ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects"); - const auto& device = ml_value->Get().Location().device; + [[maybe_unused]] const auto& device = ml_value->Get().Location().device; #ifdef _MSC_VER // The switch statement may only contain the 'default' label. In such a case, the MSVC compiler // will warn about it, and since the warnings are treated as errors, the compilation will break.