From 27b24aa32b1f4ab6e4e382a7d9fd3cc6aa4aec42 Mon Sep 17 00:00:00 2001
From: "Chunye Wang@AMD"
Date: Fri, 8 Aug 2025 18:07:00 -0500
Subject: [PATCH 1/2] [VitisAI] bugfix model_clone optimization (#25629)

### Description
This change is related to #25320 and #23979. It enables tensor raw data sharing for externalized tensor protos that carry kTensorProtoMemoryAddressTag.

### Motivation and Context
With #25320 and #23979, all initialized tensor protos are associated with an OrtValue, so the VitisAI EP needs to adapt to this change.

Co-authored-by: mingyue
---
 onnxruntime/core/providers/vitisai/imp/graph.cc       | 11 +++++++++--
 .../core/providers/vitisai/imp/tensor_proto.cc        |  2 +-
 onnxruntime/core/providers/vitisai/imp/tensor_proto.h |  1 +
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc
index 20ae1cfbfa2c1..c6bf29dafa184 100644
--- a/onnxruntime/core/providers/vitisai/imp/graph.cc
+++ b/onnxruntime/core/providers/vitisai/imp/graph.cc
@@ -15,7 +15,7 @@
 #include "vaip/node.h"
 #include "vaip/node_arg.h"
-
+#include "./tensor_proto.h"
 namespace vaip {
 struct NodeEdgeT {
@@ -286,7 +286,14 @@ Model* model_clone(const Model& original_model, int64_t external_data_threshold)
       cloned_tensor->add_dims(dim);
       size = size * dim;
     }
-    if (size >= external_data_threshold) {
+    auto ORT_MEM_ADDR_tag = process_ext_address(*original_tensor);
+    if (!ORT_MEM_ADDR_tag.empty()) {
+      cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
+      auto external_data = cloned_tensor->mutable_external_data();
+      auto p = external_data->Add();
+      *p->mutable_key() = "location";
+      *p->mutable_value() = std::string("<") + graph_ptr;
+    } else if (size >= external_data_threshold) {
       cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
       auto external_data = cloned_tensor->mutable_external_data();
       auto p = external_data->Add();
diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc
index bb942c69003a1..2f1478bf1326b 100644
--- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc
+++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc
@@ -10,7 +10,7 @@
 namespace vaip {
 using namespace onnxruntime;
-static gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) {
+gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) {
   auto tensor_proto = const_cast(&tensor);
   auto file = std::string();
   uintptr_t offset = 0;
diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h
index 73015d3411a54..a7c90ac18b44e 100644
--- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h
+++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h
@@ -37,4 +37,5 @@ ONNX_NAMESPACE::TensorProto* tensor_proto_new_fp16(const std::string& name, cons
                                                    const std::vector& data);
 ONNX_NAMESPACE::TensorProto* tensor_proto_new_doubles(const std::string& name, const std::vector& shape,
                                                       const std::vector& data);
+gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor);
 } // namespace vaip

From 10499fdfdd686bbea7ecf7890db1854c3f84bc6e Mon Sep 17 00:00:00 2001
From: Artur Wojcik
Date: Sat, 9 Aug 2025 02:10:14 +0200
Subject: [PATCH 2/2] [MIGraphX EP] Syncing AMD changes upstream (#25583)

A set of changes required for WCR/WindowsML that were added to the MIGraphX Execution provider.
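For reference on the VitisAI patch above (PATCH 1): the fix leans on ONNX's standard external-data mechanism, in which a TensorProto is marked EXTERNAL and described by key/value entries ("location", "offset", "length") instead of embedding raw bytes. The sketch below is a minimal, hypothetical illustration of that mechanism only; the tag string, helper names, and parameters are assumptions for illustration, not onnxruntime's actual kTensorProtoMemoryAddressTag constant or the EP's process_ext_address() helper.

```cpp
// Minimal sketch: mark a cloned TensorProto as EXTERNAL and point its external_data
// entries at a buffer that already lives in host memory, so the clone shares raw data
// instead of duplicating it. Placeholder names throughout; not the ORT implementation.
#include <cstddef>
#include <cstdint>
#include <string>

#include "onnx/onnx_pb.h"

namespace example {

// Placeholder for the real memory-address tag string used by onnxruntime.
constexpr const char* kInMemoryTag = "<in-memory>";

inline void AddExternalDataEntry(ONNX_NAMESPACE::TensorProto& t,
                                 const std::string& key, const std::string& value) {
  auto* kv = t.mutable_external_data()->Add();  // StringStringEntryProto
  kv->set_key(key);
  kv->set_value(value);
}

// Re-externalize `cloned`: drop any embedded bytes and record where the shared
// buffer lives, following the ONNX external-data key convention.
inline void ShareRawDataInMemory(ONNX_NAMESPACE::TensorProto& cloned,
                                 std::uintptr_t address, std::size_t num_bytes) {
  cloned.clear_raw_data();
  cloned.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
  AddExternalDataEntry(cloned, "location", kInMemoryTag);
  AddExternalDataEntry(cloned, "offset", std::to_string(address));
  AddExternalDataEntry(cloned, "length", std::to_string(num_bytes));
}

}  // namespace example
```

In the patch itself, model_clone() calls process_ext_address() on each original initializer and, when it returns a non-empty span, marks the cloned tensor as EXTERNAL with a special "location" value rather than copying the raw data.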
The development was done in the ROCm repository, now we want to sync with the main branch with a single drop. The PR incorporates the review comments from the previous closed PR #25338. Motivation and Context Fixes, changes, and updates to MIGraphX EP that have been done for ROCm development. Pushing this back upstream to ensure mainline onnxruntime is using the latest changes. Moving forward, MIGraphX EP will be cut from the latest official release tag as a base point while also adding additional features that will be contributed back. --------- Co-authored-by: urpetkov-amd <127323899+urpetkov-amd@users.noreply.github.com> Co-authored-by: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Co-authored-by: Ted Themistokleous Co-authored-by: Scott McKay --- .gitignore | 1 + cmake/CMakeLists.txt | 1 + cmake/onnxruntime_providers_migraphx.cmake | 81 ++- cmake/onnxruntime_python.cmake | 15 + cmake/onnxruntime_unittests.cmake | 4 - include/onnxruntime/core/common/common.h | 20 +- .../onnxruntime/core/common/string_helper.h | 6 +- .../core/framework/provider_options_utils.h | 38 +- .../core/session/onnxruntime_c_api.h | 9 +- onnxruntime/core/common/helper.cc | 4 +- onnxruntime/core/common/path_string.h | 14 + .../providers/migraphx/gpu_data_transfer.h | 2 +- .../providers/migraphx/migraphx_allocator.cc | 11 +- .../providers/migraphx/migraphx_allocator.h | 2 +- .../core/providers/migraphx/migraphx_call.cc | 36 +- .../core/providers/migraphx/migraphx_call.h | 4 +- .../migraphx/migraphx_execution_provider.cc | 538 +++++++++--------- .../migraphx/migraphx_execution_provider.h | 98 ++-- .../migraphx_execution_provider_info.cc | 151 ++--- .../migraphx_execution_provider_info.h | 95 ++-- .../migraphx_execution_provider_utils.h | 155 ++++- .../core/providers/migraphx/migraphx_inc.h | 1 + .../migraphx/migraphx_provider_factory.cc | 160 +++--- .../migraphx/migraphx_provider_factory.h | 23 +- .../migraphx_provider_factory_creator.h | 2 + .../migraphx/migraphx_stream_handle.cc | 71 +-- .../migraphx/migraphx_stream_handle.h | 17 +- .../providers/shared_library/provider_api.h | 6 + .../provider_bridge_provider.cc | 8 +- onnxruntime/core/session/onnxruntime_c_api.cc | 13 +- .../core/session/provider_bridge_ort.cc | 9 +- .../core/session/provider_registration.cc | 14 +- .../python/onnxruntime_pybind_mlvalue.cc | 21 +- .../python/onnxruntime_pybind_mlvalue.h | 8 +- .../python/onnxruntime_pybind_ortvalue.cc | 44 +- .../python/onnxruntime_pybind_state.cc | 128 +---- .../python/onnxruntime_pybind_state_common.cc | 6 +- .../python/onnxruntime_pybind_state_common.h | 13 +- onnxruntime/test/util/default_providers.cc | 18 +- setup.py | 25 +- tools/ci_build/build.py | 7 +- .../nuget/generate_nuspec_for_native_nuget.py | 88 ++- 42 files changed, 1108 insertions(+), 859 deletions(-) diff --git a/.gitignore b/.gitignore index 4d0a1205b7c19..b25763334f227 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ # build, distribute, and bins (+ python proto bindings) +build.*/ build build_*/ .build_debug/* diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index b0941b4d0c922..4a3ca3dcd7741 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -29,6 +29,7 @@ include(CheckLanguage) include(CMakeDependentOption) include(FetchContent) include(CheckFunctionExists) +include(CheckSymbolExists) include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables # TODO: update this once all system adapt c++20 diff --git a/cmake/onnxruntime_providers_migraphx.cmake 
b/cmake/onnxruntime_providers_migraphx.cmake index 495ff093326ad..8cb5dcf95155a 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -2,21 +2,11 @@ # Licensed under the MIT License. add_definitions(-DUSE_MIGRAPHX=1) - set(BUILD_LIBRARY_ONLY 1) - add_definitions("-DONNX_ML=1") - add_definitions("-DONNX_NAMESPACE=onnx") - include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR}) - set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME}) - include_directories(${onnx_SOURCE_DIR}) + include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR} ${onnx_SOURCE_DIR}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) + if (CMAKE_COMPILER_IS_GNUCC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") endif() - set(CXX_VERSION_DEFINED TRUE) - set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") - endif() # Add search paths for default rocm installation list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH}) @@ -33,8 +23,6 @@ find_package(hip REQUIRED) find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME}) - set(migraphx_libs migraphx::c hip::host) - file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" @@ -42,14 +30,14 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) + onnxruntime_add_shared_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) + add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE migraphx::c hip::host ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/migraphx/onnxruntime) set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") - target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1 ONNX_ML=1 ONNX_NAMESPACE=onnx) if(MSVC) set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def) target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32) @@ -62,6 +50,15 @@ 
target_link_libraries(onnxruntime_providers_migraphx PRIVATE stdc++fs) endif() + set(CMAKE_REQUIRED_LIBRARIES migraphx::c) + + check_symbol_exists(migraphx_onnx_options_set_external_data_path + "migraphx/migraphx.h" HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH) + + if(HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH=1) + endif() + if (onnxruntime_ENABLE_TRAINING_OPS) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_training) target_link_libraries(onnxruntime_providers_migraphx PRIVATE onnxruntime_training) @@ -71,15 +68,39 @@ endif() if(CMAKE_SYSTEM_NAME STREQUAL "Windows") - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - else() - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) + foreach(file migraphx-hiprtc-driver.exe migraphx.dll migraphx_c.dll migraphx_cpu.dll migraphx_device.dll migraphx_gpu.dll migraphx_onnx.dll migraphx_tf.dll) + set(_source "${AMD_MIGRAPHX_HOME}/bin/${file}") + if(EXISTS "${_source}") + add_custom_command(TARGET onnxruntime_providers_migraphx + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $) + set(_target "$/${file}") + list(APPEND _migraphx_targets ${_target}) + endif() + endforeach() + set(MIGRAPHX_LIB_FILES ${_migraphx_targets} CACHE INTERNAL "" FORCE) + install(FILES ${_migraphx_targets} + DESTINATION ${CMAKE_INSTALL_BINDIR}) + get_property(_amdhip64_location TARGET hip::amdhip64 PROPERTY IMPORTED_LOCATION_RELEASE) + cmake_path(GET _amdhip64_location PARENT_PATH _hipsdk_path) + foreach(file amd_comgr0602.dll amd_comgr0604.dll amd_comgr0700.dll hiprtc0602.dll hiprtc0604.dll hiprtc0700.dll hiprtc-builtins0602.dll hiprtc-builtins0604.dll hiprtc-builtins0700.dll) + set(_source "${_hipsdk_path}/${file}") + if(EXISTS "${_source}") + add_custom_command(TARGET onnxruntime_providers_migraphx + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $) + set(_target "$/${file}") + list(APPEND _hipsdk_targets ${_target}) + endif() + endforeach() + set(HIPSDK_LIB_FILES ${_hipsdk_targets} CACHE INTERNAL "" FORCE) + install(FILES ${_hipsdk_targets} + DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() + + install(TARGETS onnxruntime_providers_migraphx + EXPORT onnxruntime_providers_migraphxTargets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index c5c85dff96411..ae976abe62fd8 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -740,6 +740,21 @@ if (onnxruntime_USE_OPENVINO) ) endif() +if (onnxruntime_USE_MIGRAPHX) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${MIGRAPHX_LIB_FILES} + $/onnxruntime/capi/) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${HIPSDK_LIB_FILES} + $/onnxruntime/capi/) + endif() +endif() + if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) add_custom_command( TARGET 
onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 36ba8db9bdc75..a1e9d06b133fd 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -607,7 +607,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_COREML) @@ -688,9 +687,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*) - list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_NNAPI_BUILTIN) diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index adfd341451aed..820d140ccaabc 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -294,12 +294,26 @@ inline std::string ToUTF8String(const std::string& s) { return s; } /** * Convert a wide character string to a UTF-8 string */ -std::string ToUTF8String(const std::wstring& s); - -std::wstring ToWideString(const std::string& s); +std::string ToUTF8String(std::wstring_view s); +inline std::string ToUTF8String(const wchar_t* s) { + return ToUTF8String(std::wstring_view{s}); +} +inline std::string ToUTF8String(const std::wstring& s) { + return ToUTF8String(std::wstring_view{s}); +} +std::wstring ToWideString(std::string_view s); +inline std::wstring ToWideString(const char* s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::string& s) { + return ToWideString(std::string_view{s}); +} inline std::wstring ToWideString(const std::wstring& s) { return s; } +inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; } #else inline std::string ToWideString(const std::string& s) { return s; } +inline std::string ToWideString(const char* s) { return s; } +inline std::string ToWideString(std::string_view s) { return std::string{s}; } #endif constexpr size_t kMaxStrLen = 4096; diff --git a/include/onnxruntime/core/common/string_helper.h b/include/onnxruntime/core/common/string_helper.h index 1304303132d5a..c0b331cb8e9a8 100644 --- a/include/onnxruntime/core/common/string_helper.h +++ b/include/onnxruntime/core/common/string_helper.h @@ -7,5 +7,9 @@ // forward declaration struct OrtAllocator; namespace onnxruntime { -char* StrDup(const std::string& str, OrtAllocator* allocator); +char* StrDup(std::string_view str, OrtAllocator* allocator); +inline char* StrDup(const std::string& str, OrtAllocator* allocator) { + return StrDup(std::string_view{str}, allocator); +} +wchar_t* StrDup(std::wstring_view str, OrtAllocator* allocator); } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/provider_options_utils.h b/include/onnxruntime/core/framework/provider_options_utils.h index 5967fb91523d0..badb7320ea49e 100644 --- a/include/onnxruntime/core/framework/provider_options_utils.h +++ b/include/onnxruntime/core/framework/provider_options_utils.h @@ -83,12 +83,24 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddValueParser( const 
std::string& name, ValueParserType value_parser) { + return AddValueParser(std::string_view{name}, value_parser); + } + + template + ProviderOptionsParser& AddValueParser( + std::string_view name, ValueParserType value_parser) { ORT_ENFORCE( value_parsers_.emplace(name, ValueParser{value_parser}).second, "Provider option \"", name, "\" already has a value parser."); return *this; } + template + ProviderOptionsParser& AddValueParser( + const char* name, ValueParserType value_parser) { + return AddValueParser(std::string_view{name}, value_parser); + } + /** * Adds a parser for a particular provider option value which converts a * value to the right type and assigns it to the given reference. @@ -104,13 +116,25 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddAssignmentToReference( const std::string& name, ValueType& dest) { + return AddAssignmentToReference(std::string_view{name}, dest); + } + + template + ProviderOptionsParser& AddAssignmentToReference( + std::string_view name, ValueType& dest) { return AddValueParser( name, - [&dest](const std::string& value_str) -> Status { + [&dest](std::string_view value_str) -> Status { return ParseStringWithClassicLocale(value_str, dest); }); } + template + ProviderOptionsParser& AddAssignmentToReference( + const char* name, ValueType& dest) { + return AddAssignmentToReference(std::string_view{name}, dest); + } + /** * Adds a parser for a particular provider option value which maps an * enumeration name to a value and assigns it to the given reference. @@ -128,6 +152,12 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddAssignmentToEnumReference( const std::string& name, const EnumNameMapping& mapping, EnumType& dest) { + return AddAssignmentToEnumReference(std::string_view{name}, mapping, dest); + } + + template + ProviderOptionsParser& AddAssignmentToEnumReference( + std::string_view name, const EnumNameMapping& mapping, EnumType& dest) { return AddValueParser( name, [&mapping, &dest](const std::string& value_str) -> Status { @@ -135,6 +165,12 @@ class ProviderOptionsParser { }); } + template + ProviderOptionsParser& AddAssignmentToEnumReference( + const char* name, const EnumNameMapping& mapping, EnumType& dest) { + return AddAssignmentToEnumReference(std::string_view{name}, mapping, dest); + } + /** * Parses the given provider options. */ diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2899a219bdda0..6eb15280a4aa4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -754,13 +754,13 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true - int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true + int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, nonzero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name - int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true + int migraphx_save_compiled_model; // migraphx save compiled model. 
Default 0 = false, nonzero = true const char* migraphx_save_model_path; // migraphx model path name - int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true + int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, nonzero = true const char* migraphx_load_model_path; // migraphx model path name - bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false + bool migraphx_exhaustive_tune; // MIGraphX tuned compile. Default = false, nonzero = true /** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t) * Defaults to SIZE_MAX. @@ -776,6 +776,7 @@ typedef struct OrtMIGraphXProviderOptions { */ int migraphx_arena_extend_strategy; + // This is the legacy struct and don't add new fields here. } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/common/helper.cc b/onnxruntime/core/common/helper.cc index 6a52db73df106..07cd1672b27c1 100644 --- a/onnxruntime/core/common/helper.cc +++ b/onnxruntime/core/common/helper.cc @@ -18,7 +18,7 @@ namespace onnxruntime { #ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { +std::string ToUTF8String(std::wstring_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); @@ -33,7 +33,7 @@ std::string ToUTF8String(const std::wstring& s) { return ret; } -std::wstring ToWideString(const std::string& s) { +std::wstring ToWideString(std::string_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); diff --git a/onnxruntime/core/common/path_string.h b/onnxruntime/core/common/path_string.h index 6cfb327cce08a..4ca326d76a37d 100644 --- a/onnxruntime/core/common/path_string.h +++ b/onnxruntime/core/common/path_string.h @@ -40,6 +40,12 @@ inline PathString ToPathString(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::wstring!"); +inline PathString ToPathString(std::string_view s) { + return ToWideString(s); +} +inline PathString ToPathString(const char* s) { + return ToWideString(s); +} inline PathString ToPathString(const std::string& s) { return ToWideString(s); } @@ -56,6 +62,14 @@ inline std::string PathToUTF8String(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::string!"); +inline PathString ToPathString(const char* s) { + return s; +} + +inline PathString ToPathString(std::string_view s) { + return PathString{s}; +} + inline PathChar ToLowerPathChar(PathChar c) { return std::tolower(c); } diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h index 5918716b3e77f..a4eb8efd2afea 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h @@ -3,7 +3,7 @@ #pragma once -#include "migraphx_inc.h" +#include "core/providers/migraphx/migraphx_inc.h" #include "core/framework/data_transfer.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index 1cac133ab0c2c..911a1a7fd18b9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -2,12 +2,11 @@ // Licensed under the MIT License. 
#include "core/providers/shared_library/provider_api.h" -#include "migraphx_call.h" -#include "migraphx_allocator.h" +#include "core/providers/migraphx/migraphx_call.h" +#include "core/providers/migraphx/migraphx_allocator.h" #include "core/common/status.h" #include "core/framework/float16.h" -#include "core/common/status.h" -#include "gpu_data_transfer.h" +#include "core/providers/migraphx/gpu_data_transfer.h" namespace onnxruntime { @@ -55,7 +54,9 @@ void MIGraphXExternalAllocator::Free(void* p) { auto it = reserved_.find(p); if (it != reserved_.end()) { reserved_.erase(it); - if (empty_cache_) empty_cache_(); + if (empty_cache_ != nullptr) { + empty_cache_(); + } } } diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index f6b7788e0604c..10e06ab2f35ad 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -3,9 +3,9 @@ #pragma once +#include #include #include "core/framework/allocator.h" -#include namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc index 9807cd646e51c..79dfb5512d3b5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -1,13 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #ifdef _WIN32 #include #else #include #endif -#include #include "core/common/common.h" #include "core/common/status.h" #include "core/providers/shared_library/provider_api.h" @@ -15,10 +17,9 @@ namespace onnxruntime { -using namespace common; - +namespace { template -const char* RocmErrString(ERRTYPE x) { +std::string_view RocmErrString(ERRTYPE x) { ORT_NOT_IMPLEMENTED(); } @@ -27,14 +28,16 @@ const char* RocmErrString(ERRTYPE x) { return #x template <> -const char* RocmErrString(hipError_t x) { +std::string_view RocmErrString(hipError_t x) { (void)hipDeviceSynchronize(); - return hipGetErrorString(x); + return std::string_view{hipGetErrorString(x)}; } +} // namespace + template std::conditional_t RocmCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { + ERRTYPE retCode, std::string_view exprString, std::string_view libName, ERRTYPE successCode, std::string_view msg, std::string_view file, const int line) { if (retCode != successCode) { try { #ifdef _WIN32 @@ -47,17 +50,16 @@ std::conditional_t RocmCall( int currentHipDevice; (void)hipGetDevice(¤tHipDevice); (void)hipGetLastError(); // clear last HIP error - static char str[1024]; - snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", - libName, (int)retCode, RocmErrString(retCode), currentHipDevice, - hostname.c_str(), - file, line, exprString, msg); + std::stringstream ss; + ss << libName << " failure " << static_cast(retCode) << ": " << RocmErrString(retCode) + << "; GPU=" << currentHipDevice << "; hostname=" << hostname << "; file=" << file << "; line=" << line + << "; expr=" << exprString << "; " << msg; if constexpr (THRW) { // throw an exception with the error info - ORT_THROW(str); + ORT_THROW(ss.str()); } else { - LOGS_DEFAULT(ERROR) << str; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + LOGS_DEFAULT(ERROR) << ss.str(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ss.str()); } } catch (const std::exception& 
e) { // catch, log, and rethrow since HIP code sometimes hangs in destruction, so we'd never get to see the error if constexpr (THRW) { @@ -73,7 +75,7 @@ std::conditional_t RocmCall( } } -template Status RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); +template Status RocmCall(hipError_t retCode, std::string_view exprString, std::string_view libName, hipError_t successCode, std::string_view msg, std::string_view file, int line); +template void RocmCall(hipError_t retCode, std::string_view exprString, std::string_view libName, hipError_t successCode, std::string_view msg, std::string_view file, int line); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h index 6d514e01aea96..9c3b5c79a947b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.h +++ b/onnxruntime/core/providers/migraphx/migraphx_call.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "migraphx_inc.h" +#include "core/providers/migraphx/migraphx_inc.h" #include "core/common/common.h" namespace onnxruntime { @@ -13,7 +13,7 @@ namespace onnxruntime { template std::conditional_t RocmCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); + ERRTYPE retCode, std::string_view exprString, std::string_view libName, ERRTYPE successCode, std::string_view msg, std::string_view file, int line); #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 41b55e3baf508..a59347841be95 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1,26 +1,34 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License -#include + +#include + #include +#include +#include +#include #include -#include +#include +#include #include -#include +#include +#include +#include +#include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/safeint.h" #include "core/common/logging/severity.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_execution_provider_utils.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" -#include -#include "migraphx_call.h" - -#include "migraphx_stream_handle.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_execution_provider_utils.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_call.h" +#include "core/providers/migraphx/migraphx_stream_handle.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245) @@ -105,240 +113,144 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c return s_kernel_registry; } -MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, - info.device_id)}, - info_(info) { - InitProviderOrtApi(); - get_flags_from_session_info(info); - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); - get_flags_from_env(); +static std::string_view GetArenaExtendStrategyName(ArenaExtendStrategy strategy) { + switch (strategy) { + case ArenaExtendStrategy::kNextPowerOfTwo: + return "kNextPowerOfTwo"; + case ArenaExtendStrategy::kSameAsRequested: + return "kSameAsRequested"; + default: + return "Unknown"; + } } -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { -} +#define GET_ENV(variable, value, ...) 
\ + const auto value##env{GetEnvironmentVar(variable)}; \ + if (!value##env.empty()) { \ + __VA_ARGS__; \ + LOGS_DEFAULT(INFO) << "\n " << variable << ": " << value##env; \ + } -void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { - // Set GPU device to be used - HIP_CALL_THROW(hipSetDevice(info_.device_id)); - t_ = migraphx::target(info.target_device.c_str()); +#define GET_ENV_BOOL(variable, value) \ + GET_ENV(variable, value, value = std::stoi(value##env) != 0) - // Quantization - fp16_enable_ = info.fp16_enable; +#define GET_ENV_STRING(variable, value) \ + GET_ENV(variable, value, value = value##env) +MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) + : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, info.device_id)}, + device_id_{info.device_id}, + fp16_enable_{info.fp16_enable}, +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR > 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH >= 2))) + bf16_enable_{info.bf16_enable}, +#endif #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = info.fp8_enable; -#else - LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; - fp8_enable_ = false; + fp8_enable_{info.fp8_enable}, #endif - int8_enable_ = info.int8_enable; - - if (int8_enable_ and fp8_enable_) { - LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; - } - - if (int8_enable_ xor fp8_enable_) { - int8_calibration_cache_name_ = info.int8_calibration_table_name; - int8_use_native_migraphx_calibration_table_ = info.int8_use_native_calibration_table; - } - - if (int8_enable_ or fp8_enable_) { - int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); - } + int8_enable_{info.int8_enable}, + model_cache_path_{info.model_cache_dir}, + t_{info.target_device.c_str()}, + exhaustive_tune_{info.exhaustive_tune}, + metadef_id_generator_{ModelMetadefIdGenerator::Create()}, + external_alloc_{info.external_alloc}, + external_free_{info.external_free}, + external_empty_cache_{info.external_empty_cache} { + InitProviderOrtApi(); - // Load INT8 calibration table - std::unordered_map dynamic_range_map; - if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path); - } - } + // Set GPU device to be used and read device properties for feature usage. - // Save/load migraphx compiled models - save_compiled_model_ = info.save_compiled_model; - save_compiled_path_ = info.save_model_file; - load_compiled_model_ = info.load_compiled_model; - load_compiled_path_ = info.load_model_file; + HIP_CALL_THROW(hipSetDevice(device_id_)); + HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, device_id_)); - exhaustive_tune_ = info.exhaustive_tune; + // Overwrite initialized values with values from environment variables. 
- LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider Session Options:"; - print_migraphx_ep_flags(); -} + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX ENV Override Variables Set:"; + GET_ENV_BOOL(migraphx_env_vars::kFP16Enable, fp16_enable_); +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR > 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH >= 2))) + GET_ENV_BOOL(migraphx_env_vars::kBF16Enable, bf16_enable_); +#endif +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) + GET_ENV_BOOL(migraphx_env_vars::kFP8Enable, fp8_enable_); +#endif + GET_ENV_BOOL(migraphx_env_vars::kINT8Enable, int8_enable_); + GET_ENV(migraphx_env_vars::kINT8CalibrationTableName, int8_calibration_cache_name_); + GET_ENV(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable, int8_use_native_migraphx_calibration_table_); + GET_ENV_STRING(migraphx_env_vars::kCachePath, calibration_cache_path_); + GET_ENV_STRING(migraphx_env_vars::kModelCachePath, model_cache_path_); + GET_ENV_BOOL(migraphx_env_vars::kDumpModelOps, dump_model_ops_); + GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); + + // Verify configuration correctness and adjust accordingly. + +#if HIP_VERSION_MAJOR < 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR < 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH < 2))) + LOGS_DEFAULT(WARNING) << "MIGraphX: BF16 Quantization requires ROCm 6.4.2 or greater"; + bf16_enable_ = false; +#endif -void MIGraphXExecutionProvider::get_flags_from_env() { - LOGS_DEFAULT(WARNING) << "\n[MIGraphX EP] MIGraphX ENV Override Variables Set:"; - // whether fp16 is enable - const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); - if (!fp16_enable_env.empty()) { - fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; + if (bf16_enable_ && fp16_enable_) { + bf16_enable_ = false; + fp16_enable_ = false; + LOGS_DEFAULT(FATAL) << "MIGraphX: BF16 and FP16 Quantization Mutually exclusive. Ignoring both Quantization flags"; } - // whether fp8 quantization is enabled - const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); - if (!fp8_enable_env.empty()) { -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; -#else - LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; - fp8_enable = false; +#if HIP_VERSION_MAJOR < 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR < 4) + LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; + fp8_enable_ = false; #endif - } - // whether int8 is enabled - const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); - if (!int8_enable_env.empty()) { - int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; + if (int8_enable_ && fp8_enable_) { + LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } - if (int8_enable_ and fp8_enable_) { - LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP8 and INT8 Quantization Mutually exclusive. 
Ignoring both Quantization flags"; + if (int8_enable_ ^ fp8_enable_) { + int8_calibration_table_name_ = + int8_calibration_cache_name_env.empty() ? info.int8_calibration_table_name : int8_calibration_cache_name_env; + int8_use_native_calibration_table_ = + int8_use_native_migraphx_calibration_table_env.empty() ? info.int8_use_native_calibration_table : std::stoi(int8_use_native_migraphx_calibration_table_env) != 0; } if (int8_enable_ || fp8_enable_) { - const std::string int8_calibration_cache_name_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); - if (!int8_calibration_cache_name_env.empty()) { - int8_calibration_cache_name_ = int8_calibration_cache_name_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; - } - - const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); - if (!cache_path.empty()) { - calibration_cache_path_ = cache_path; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; - } - - const std::string int8_use_native_migraphx_calibration_table_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); - if (!int8_use_native_migraphx_calibration_table_env.empty()) { - int8_use_native_migraphx_calibration_table_ = - (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " - << int8_use_native_migraphx_calibration_table_; - } - } - - if (int8_enable_ or fp8_enable_) { - int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); + int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); } // Load INT8 calibration table - std::unordered_map dynamic_range_map; if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path); + std::unordered_map dynamic_range_map; + auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_table_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path.string()); } } - // Save/load migraphx compiled models - const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); - if (!save_comp_model_env.empty()) { - save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_MODEL: " << save_compiled_model_; - } - - const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); - if (save_compiled_model_ && !save_model_path_env.empty()) { - save_compiled_path_ = save_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_PATH: " << save_compiled_path_; - } - - const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); - if (!load_comp_model_env.empty()) { - load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? 
false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_MODEL: " << load_compiled_model_; - } - - const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); - if (load_compiled_model_ && !load_model_path_env.empty()) { - load_compiled_path_ = load_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_PATH: " << load_compiled_path_; - } - - // dump unsupported ops - const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); - if (!dump_model_ops_env.empty()) { - dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; - } + // Print configured options for the session. - // Allow for exhaustive tune during compile - const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); - if (!exhaustive_tune_env.empty()) { - exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; - } -} - -void MIGraphXExecutionProvider::print_migraphx_ep_flags() { - LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id - << "\n migraphx_fp16_enable: " << fp16_enable_ - << "\n migraphx_fp8_enable: " << fp8_enable_ - << "\n migraphx_int8_enable: " << int8_enable_ + LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider Session Options:" + << "\n " << migraphx_provider_option::kDeviceId << ": " << device_id_ + << "\n " << migraphx_provider_option::kFp16Enable << ": " << fp16_enable_ + << "\n " << migraphx_provider_option::kBf16Enable << ": " << bf16_enable_ + << "\n " << migraphx_provider_option::kFp8Enable << ": " << fp8_enable_ + << "\n " << migraphx_provider_option::kInt8Enable << ": " << int8_enable_ + << "\n " << migraphx_provider_option::kMemLimit << ": " << mem_limit_ + << "\n " << migraphx_provider_option::kArenaExtendStrategy << ": " << GetArenaExtendStrategyName(arena_extend_strategy_) << "\n dump_model_ops: " << dump_model_ops_ - << "\n exhaustive_tune: " << exhaustive_tune_ - << "\n migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << "\n " << migraphx_provider_option::kExhaustiveTune << ": " << exhaustive_tune_ + << "\n " << migraphx_provider_option::kInt8CalibTable << ": " << int8_calibration_table_name_ << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ - << "\n use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ - << "\n migraphx_save_compiled_model: " << save_compiled_model_ - << "\n migraphx_save_compiled_model_path: " << save_compiled_path_ - << "\n migraphx_load_compiled_model: " << load_compiled_model_ - << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; -} - -AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, - size_t migx_mem_limit, - ArenaExtendStrategy arena_extend_strategy, - MIGraphXExecutionProviderExternalAllocatorInfo - external_allocator_info, - const OrtArenaCfg* default_memory_arena_cfg) { - if (external_allocator_info.UseExternalAllocator()) { - AllocatorCreationInfo default_memory_info( - [external_allocator_info](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP, - external_allocator_info.alloc, - external_allocator_info.free, - external_allocator_info.empty_cache); - }, - device_id, - false); - - return CreateAllocator(default_memory_info); - } else 
{ - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP); - }, - device_id, - true, - {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), - -1, -1, -1, -1L)}, - // make it stream aware - true); - - // ROCM malloc/free is expensive so always use an arena - return CreateAllocator(default_memory_info); - } + << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_calibration_table_ + << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_; } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, - info_.device_id); + [](OrtDevice::DeviceId device_id) { + return std::make_unique(device_id, onnxruntime::CUDA); + }, + device_id_); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, onnxruntime::CUDA_PINNED); + return std::make_unique(device_id, CUDA_PINNED); }, - info_.device_id); + device_id_); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; } @@ -354,6 +266,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: @@ -384,6 +297,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: mgx_type = migraphx_shape_half_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + mgx_type = migraphx_shape_bf16_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: mgx_type = migraphx_shape_float_type; break; @@ -457,7 +373,7 @@ std::vector toVector(const ONNX_NAMESPACE::int64s& nums) { static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) { std::vector input_nodes; const auto& optype = node->OpType(); - if (optype == "ArgMax" or optype == "ArgMin") { + if (optype == "ArgMax" || optype == "ArgMin") { const auto& attributes = node->GetAttributes(); // we do not support select_last_index = 1 for now auto sli_attr = attributes.find("select_last_index"); @@ -475,7 +391,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } - if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and + if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) && (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) { return true; } @@ -503,7 +419,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co // storage order 1 (column major format) is not supported auto storage_order_attr = attributes.find("storage_order"); - if (storage_order_attr != attributes.end() and (*storage_order_attr).second.i() != 0) { + if (storage_order_attr != attributes.end() && (*storage_order_attr).second.i() != 0) { return true; } @@ -513,7 +429,7 
@@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } auto data_type = input_type->tensor_type().elem_type(); - if (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 or + if (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) { return true; } @@ -524,7 +440,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } - if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and + if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) && (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) { return true; } @@ -580,7 +496,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } return true; } - } else if (optype == "Resize" or optype == "Upsample") { + } else if (optype == "Resize" || optype == "Upsample") { const auto& attributes = node->GetAttributes(); auto ct_attr = attributes.find("coordinate_transformation_mode"); if (ct_attr != attributes.end()) { @@ -618,7 +534,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } const auto& attributes = node->GetAttributes(); - if (attributes.count("starts") > 0 and attributes.count("ends") > 0) { + if (attributes.count("starts") > 0 && attributes.count("ends") > 0) { auto starts = toVector((*attributes.find("starts")).second.ints()); auto ends = toVector((*attributes.find("ends")).second.ints()); for (std::size_t i = 0; i < starts.size(); ++i) { @@ -656,7 +572,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } - } else if (optype == "Unsqueeze" or optype == "Squeeze") { + } else if (optype == "Unsqueeze" || optype == "Squeeze") { const auto& args = node->InputDefs(); if (args.size() == 2) { if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { @@ -685,9 +601,9 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v if (args.size() == 2) { std::vector node_inputs; if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) { - return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) { - return std::find(git.begin(), git.end(), index) != git.end(); - })); + return !std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto i) { + return std::find(git.begin(), git.end(), i) != git.end(); + }); } else { return true; } @@ -857,12 +773,14 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st erased.insert(output); } // Only when output is neither in input list nor erased list, add the output to output list - else if (erased.find(output) == erased.end()) { - if (std::find(graph_output_names.begin(), - graph_output_names.end(), output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; + else { + if (erased.find(output) == erased.end()) { + if (std::find(graph_output_names.begin(), + graph_output_names.end(), output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; } - fused_outputs[output] = output_order++; } } } @@ -944,6 +862,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Atan", 
"Atanh", "ATen", + "Attention", "AveragePool", "BatchNormalization", "BiasGelu", @@ -986,6 +905,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Greater", "GreaterOrEqual", "GroupNormalization", + "GroupNorm", "GroupQueryAttention", "HardSigmoid", "HardSwish", @@ -1017,6 +937,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "MultiHeadAttention", "Neg", "NegativeLogLikelihoodLoss", + "NhwcConv", "NonMaxSuppression", "NonZero", "Not", @@ -1053,6 +974,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "ReverseSequence", "RNN", "Roialign", + "RotaryEmbedding", "Round", "Scatter", "ScatterElements", @@ -1243,29 +1165,25 @@ bool get_input_output_names(const GraphViewer& graph, // Attempt to load a model and catch any exceptions on load fail. // Useful to default to EP to trigger the compile if file doesn't exist or loading fails. -bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { - try { - if (load_enable) { - LOGS_DEFAULT(WARNING) << "Attempting to load model at:" << path; - prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(WARNING) << "load model : Success"; - return true; - } else { - return false; - } - } catch (...) { - return false; +bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path& path) try { + if (!path.empty() && exists(path)) { + LOGS_DEFAULT(VERBOSE) << "Attempting to load model at:" << path.string(); + prog = migraphx::load(path.string().c_str()); + LOGS_DEFAULT(VERBOSE) << "load model : Success"; + return true; } return false; +} catch (...) { + return false; } -void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { - if (save_enable) { - LOGS_DEFAULT(WARNING) << "Model Save at " << out_path << ": Begin"; +void save_compiled_model(const migraphx::program& prog, const std::filesystem::path& path) { + if (!path.empty()) { + LOGS_DEFAULT(VERBOSE) << "Model Save at " << path.string() << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); - migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(WARNING) << "Model Save: Complete"; + save(prog, path.string().c_str(), fo); + LOGS_DEFAULT(VERBOSE) << "Model Save: Complete"; } } @@ -1275,12 +1193,13 @@ void calibrate_and_quantize(migraphx::program& prog, const migraphx::target& t, const migraphx::program_parameters quant_params, bool fp16_enable, + bool bf16_enable, bool int8_enable, bool fp8_enable, bool int8_calibration_cache_available, std::unordered_map& dynamic_range_map) { // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { + if ((int8_enable ^ fp8_enable) && int8_calibration_cache_available) { LOGS_DEFAULT(WARNING) << "Quantizing input program"; auto param_shapes = prog.get_parameter_shapes(); @@ -1317,6 +1236,14 @@ void calibrate_and_quantize(migraphx::program& prog, migraphx::quantize_fp16(prog); LOGS_DEFAULT(WARNING) << "Quantizing fp16: Complete"; } + +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) + if (bf16_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to bf16"; + migraphx::quantize_bf16(prog); + LOGS_DEFAULT(WARNING) << "Quantizing bf16: Complete"; + } +#endif } void compile_program(migraphx::program& prog, @@ -1330,6 +1257,27 @@ void compile_program(migraphx::program& prog, LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; } +std::string to_hex(const uint64_t 
v) { + std::array s{}; + auto [ptr, _] = std::to_chars(s.data(), s.data() + s.size(), v, 16); + return std::string{s.data(), ptr}; +} + +template +std::string make_hash(T v) { + std::array temp{}; + MurmurHash3::x86_128(v.data(), gsl::narrow_cast(v.size()), temp[0], temp.data()); + return to_hex(temp[0] | static_cast(temp[1]) << 32); +} + +template <> +std::string make_hash(const char* v) { + return make_hash(std::string_view{v}); +} + +constexpr std::uint64_t MIGraphX_Version = + ((MIGRAPHX_VERSION_MAJOR << 16) | (MIGRAPHX_VERSION_MINOR << 8) | MIGRAPHX_VERSION_PATCH); + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1337,6 +1285,33 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& for (const auto& fused_node_graph : fused_nodes) { const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; const Node& fused_node = fused_node_graph.fused_node; + + std::filesystem::path model_cache_file; + auto mxr_filename_prefix = to_hex(MIGraphX_Version) + "-" + GenerateGraphId(graph_body_viewer) + "-" + make_hash(std::string_view(device_prop_.gcnArchName)) + "-"; + + // Get model input names (only first layer) + const Graph* cur_graph = &graph_body_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + const Graph& main_graph = *cur_graph; + const auto& input_tensor = main_graph.GetInputs(); + for (auto i : input_tensor) { + session_input_names.insert(i->Name()); + } + + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + std::vector input_shapes; + for (std::size_t i = 0; i < session_input_names.size(); ++i) { + auto tensor_shape = input_tensor[i]->Shape(); + for (int j = 1; j < tensor_shape->dim_size(); ++j) { + input_shapes.push_back(tensor_shape->dim(j).dim_value()); + } + } + model_cache_file = model_cache_path_ / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + // map parameter input name to index std::unordered_map input_name_index; const auto& input_defs = fused_node.InputDefs(); @@ -1367,15 +1342,20 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::program prog; if (!no_input_shape) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; + if (!load_precompiled_model(prog, model_cache_file)) { + LOGS_DEFAULT(VERBOSE) << "No input shapes detected quantizing model"; +#ifndef ENABLE_TRAINING_CORE +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + options.set_external_data_path(model_path_.parent_path().string()); +#endif +#endif prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); migraphx::program_parameters quant_params; - calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, bf16_enable_, int8_enable_, fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); compile_program(prog, t_, exhaustive_tune_); - save_compiled_model(prog, save_compiled_model_, save_compiled_path_); + save_compiled_model(prog, model_cache_file); } auto prog_output_shapes = prog.get_output_shapes(); @@ -1396,10 +1376,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], 
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, - save_compiled_model_, save_compiled_path_, - load_compiled_model_, load_compiled_path_, dump_model_ops_}; + model_cache_path_.string(), dump_model_ops_}; *state = p.release(); return 0; }; @@ -1409,7 +1388,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& delete static_cast(state); }; - compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* api, OrtKernelContext* context) { Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); @@ -1421,6 +1400,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool bf16_enable = mgx_state->bf16_enable; bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; @@ -1429,8 +1409,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // from input data bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; + std::vector input_shapes; + if (no_input_shape) { - LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; + LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1442,7 +1424,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; + LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1459,8 +1441,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto mgx_s = param_shapes[name]; auto mgx_lens = mgx_s.lengths(); auto mgx_strides = mgx_s.strides(); - if (mgx_lens.size() == 1 and mgx_lens[0] == 1 and - mgx_strides.size() == 1 and mgx_strides[0] == 0) { + if (mgx_lens.size() == 1 && mgx_lens[0] == 1 && + mgx_strides.size() == 1 && mgx_strides[0] == 0) { mgx_lens.clear(); } @@ -1468,6 +1450,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& cmp_options.set_input_parameter_shape(name, ort_lens); input_shape_match = false; } + input_shapes.insert(input_shapes.end(), tensor_shape.begin(), tensor_shape.end()); } } } @@ -1476,20 +1459,25 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // input shapes are different, needs to re-parse onnx and // re-compile the program if (!input_shape_match) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. 
Recompiling" << std::endl; + std::filesystem::path model_cache_file; + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + model_cache_file = mgx_state->model_cache_dir / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + if (!load_precompiled_model(prog, model_cache_file)) { + LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling"; #ifndef ENABLE_TRAINING_CORE -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 2) - cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string()); +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + cmp_options.set_external_data_path(model_path_.parent_path().string()); #endif #endif prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); migraphx::program_parameters quant_params; - if ((int8_enable xor fp8_enable) and int8_calibration_cache_available) { - auto param_shapes = prog.get_parameter_shapes(); + if ((int8_enable ^ fp8_enable) && int8_calibration_cache_available) { + auto local_param_shapes = prog.get_parameter_shapes(); // Add input parameter data and the values they're set to - for (auto&& name : param_shapes.names()) { + for (auto&& name : local_param_shapes.names()) { if (map_input_name_index.count(name) > 0) { auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); @@ -1498,19 +1486,19 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx_shape_datatype_t mgx_type; getMIGraphXType(tensor_type, mgx_type); - auto mgx_s = param_shapes[name]; + auto mgx_s = local_param_shapes[name]; if (mgx_type != mgx_s.type()) { LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - quant_params.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); + quant_params.add(name, migraphx::argument(local_param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } } } - calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, + calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, fp8_enable, int8_calibration_cache_available, map_dynamic_range); compile_program(prog, t, exhaustive_tune_); - save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); + save_compiled_model(prog, model_cache_file); } mgx_state->prog = prog; @@ -1524,7 +1512,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; + LOGS_DEFAULT(VERBOSE) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1538,21 +1526,21 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; + LOGS_DEFAULT(VERBOSE) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } - // It is a output argument + // It is an output argument else { - auto compute_output_index = [](const std::string& name) -> int { - std::string out_name_prefix = "#output_"; - auto pos = 
name.find(out_name_prefix); - if (pos == std::string::npos) { + auto compute_output_index = [](const std::string_view sv) -> int { + constexpr std::string_view out_name_prefix = "#output_"; + const auto pos = sv.find(out_name_prefix); + if (pos == std::string_view::npos) { return -1; } - std::string index_str = name.substr(pos + out_name_prefix.length()); - return std::stoi(index_str); + const auto index_str = sv.substr(pos + out_name_prefix.length()); + return ToInteger(Trim(index_str, std::isdigit)); }; int output_index = compute_output_index(name); @@ -1599,7 +1587,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& static_cast(rocm_stream))); } } - }; + } return Status::OK(); }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index aecccdd54d697..99f790b9f9f7a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,33 +3,37 @@ #pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" -#include +#include "core/framework/provider_options_utils.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/migraphx/migraphx_call.h" -#include -#include -#include +using namespace std::literals::string_view_literals; namespace onnxruntime { namespace migraphx_env_vars { -static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; -static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; -static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; -static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; -static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; -static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; -static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; -static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; -static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"; -static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; -static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"; -static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; - -}; // namespace migraphx_env_vars +constexpr auto kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE"sv; +constexpr auto kBF16Enable = "ORT_MIGRAPHX_BF16_ENABLE"sv; +constexpr auto kFP8Enable = "ORT_MIGRAPHX_FP8_ENABLE"sv; +constexpr auto kINT8Enable = "ORT_MIGRAPHX_INT8_ENABLE"sv; +constexpr auto kDumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS"sv; +constexpr auto kINT8CalibrationTableName = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"sv; +constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH"sv; +constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"sv; +constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; +constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; +} // namespace migraphx_env_vars // Information to construct kernel function state. 
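Editor's note on the `compute_output_index` rewrite above: fused-node outputs reach the EP as parameters whose names carry a `#output_<index>` marker, and the lambda now extracts that index with the `string_view`-based `Trim`/`ToInteger` helpers added to `migraphx_execution_provider_utils.h` later in this patch. The sketch below is a simplified, self-contained equivalent of that parsing (it uses `std::from_chars` as a stand-in for the `ToInteger(Trim(...))` pair, and the parameter names in `main` are illustrative, not taken from a real model):

```cpp
#include <cassert>
#include <charconv>
#include <string_view>

// Sketch of the "#output_<n>" parsing done by compute_output_index; returns -1
// when the marker is absent or the suffix is not numeric.
static int ComputeOutputIndex(std::string_view name) {
  constexpr std::string_view prefix = "#output_";
  const auto pos = name.find(prefix);
  if (pos == std::string_view::npos) return -1;
  const auto digits = name.substr(pos + prefix.size());
  int value = 0;
  auto [ptr, ec] = std::from_chars(digits.data(), digits.data() + digits.size(), value);
  return ec == std::errc() ? value : -1;
}

int main() {
  assert(ComputeOutputIndex("MIGraphX_1234_#output_2") == 2);  // illustrative name
  assert(ComputeOutputIndex("input_ids") == -1);               // not an output parameter
  return 0;
}
```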
struct MIGraphXFuncState { @@ -44,14 +48,12 @@ struct MIGraphXFuncState { std::mutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool bf16_enable = false; bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; - bool save_compiled_mode = false; - std::string save_compiled_path; - bool load_compiled_mode = false; - std::string load_compiled_path; + std::filesystem::path model_cache_dir; bool dump_model_ops = false; bool exhaustive_tune = false; }; @@ -60,11 +62,7 @@ struct MIGraphXFuncState { class MIGraphXExecutionProvider : public IExecutionProvider { public: explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); - ~MIGraphXExecutionProvider(); - - void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); - void get_flags_from_env(); - void print_migraphx_ep_flags(); + ~MIGraphXExecutionProvider() override = default; Status Sync() const override; @@ -81,42 +79,55 @@ class MIGraphXExecutionProvider : public IExecutionProvider { common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; - virtual std::shared_ptr GetKernelRegistry() const override; + std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; - static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, - MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); - std::unique_ptr GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; - int GetDeviceId() const override { return info_.device_id; } + int GetDeviceId() const override { return device_id_; } ProviderOptions GetProviderOptions() const override { - return MIGraphXExecutionProviderInfo::ToProviderOptions(info_); + return { + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(device_id_)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(fp16_enable_)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable_)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable_)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable_)}, + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name_)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table_)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune_)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit_)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy_)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc_)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free_)}, + 
{std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache_)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_path_)}}; } private: - MIGraphXExecutionProviderInfo info_; + OrtDevice::DeviceId device_id_{0}; bool fp16_enable_ = false; + bool bf16_enable_ = false; bool fp8_enable_ = false; bool int8_enable_ = false; - std::string int8_calibration_cache_name_; + std::string int8_calibration_table_name_; bool int8_calibration_cache_available_ = false; - bool int8_use_native_migraphx_calibration_table_ = false; - std::string calibration_cache_path_; + bool int8_use_native_calibration_table_ = false; + std::filesystem::path calibration_cache_path_{}; std::unordered_map dynamic_range_map_; - bool save_compiled_model_ = false; - std::string save_compiled_path_; - bool load_compiled_model_ = false; - std::string load_compiled_path_; + std::filesystem::path model_cache_path_{}; + std::set session_input_names; bool dump_model_ops_ = false; migraphx::target t_; std::mutex mgx_mu_; hipStream_t stream_ = nullptr; + hipDeviceProp_t device_prop_{}; bool exhaustive_tune_ = false; - mutable std::filesystem::path model_path_; + mutable std::filesystem::path model_path_{}; + size_t mem_limit_{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy_{ArenaExtendStrategy::kNextPowerOfTwo}; std::unordered_map map_progs_; std::unordered_map map_onnx_string_; @@ -125,6 +136,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { AllocatorPtr allocator_; std::unique_ptr metadef_id_generator_; + void* external_alloc_{nullptr}; + void* external_free_{nullptr}; + void* external_empty_cache_{nullptr}; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index cf21d791cfe6b..33ef366eb18e5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -1,14 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
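Editor's note on the caching members above: the old save/load compiled-model knobs are folded into a single `model_cache_path_` directory, and the `Compile()` hunk earlier in this patch derives one `.mxr` file per compiled graph from the packed MIGraphX version, a hash-based graph id, the device's `gcnArchName`, and the concrete input shapes. The sketch below only mirrors that filename layout; it is not the real implementation: `FakeHash` (plain `std::hash`) stands in for the MurmurHash3-based `make_hash`, the 2.12.0 version and `gfx90a` arch are hypothetical, and exactly which dimensions feed the shape hash differs between the compile-time and runtime paths in the actual code.

```cpp
#include <charconv>
#include <cstdint>
#include <filesystem>
#include <functional>
#include <string>
#include <vector>

// Hex formatting as done by to_hex() in the patch.
static std::string ToHex(std::uint64_t v) {
  char buf[32] = {};
  auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), v, 16);
  return std::string(buf, ptr);
}

// Illustrative stand-in for the MurmurHash3-based make_hash().
static std::string FakeHash(const std::string& s) {
  return ToHex(std::hash<std::string>{}(s));
}

// Mirrors the cache-key layout used in Compile():
//   <version hex>-<graph id>-<gcnArchName hash>-<input shape hash>.mxr
static std::filesystem::path MakeCacheFile(const std::filesystem::path& cache_dir,
                                           std::uint64_t packed_version,  // (major<<16)|(minor<<8)|patch
                                           const std::string& graph_id,
                                           const std::string& gcn_arch_name,
                                           const std::vector<std::int64_t>& input_shape) {
  std::string shape_str;
  for (auto d : input_shape) shape_str += std::to_string(d) + ",";
  return cache_dir / (ToHex(packed_version) + "-" + graph_id + "-" +
                      FakeHash(gcn_arch_name) + "-" + FakeHash(shape_str) + ".mxr");
}

int main() {
  const std::uint64_t version = (2ull << 16) | (12ull << 8) | 0;  // hypothetical 2.12.0
  auto f = MakeCacheFile("/tmp/mxr_cache", version, "1a2b3c", "gfx90a", {1, 3, 224, 224});
  (void)f;  // illustrative path; real entries are produced by the EP itself
  return 0;
}
```

An empty `model_cache_path_` keeps the pre-existing behaviour of compiling on every session load, which is why both the compile-time and runtime paths check for an empty path before touching the cache.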
+#include + #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/common/make_string.h" #include "core/common/parse_string.h" -#include "core/framework/provider_options_utils.h" -#include "migraphx_inc.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/migraphx_inc.h" +#include "core/providers/migraphx/migraphx_call.h" namespace onnxruntime { @@ -17,118 +18,90 @@ const EnumNameMapping arena_extend_strategy_mapping{ {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, }; -namespace migraphx { -namespace provider_option_names { -constexpr const char* kDeviceId = "device_id"; -constexpr const char* kFp16Enable = "trt_fp16_enable"; -constexpr const char* kFp8Enable = "migx_fp8_enable"; -constexpr const char* kInt8Enable = "migx_int8_enable"; -constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; -constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; -constexpr const char* kSaveCompiledModel = "migx_save_compiled_model"; -constexpr const char* kSaveModelPath = "migx_save_model_name"; -constexpr const char* kLoadCompiledModel = "migx_load_compiled_model"; -constexpr const char* kLoadModelPath = "migx_load_model_name"; -constexpr const char* kExhaustiveTune = "migx_exhaustive_tune"; -constexpr const char* kMemLimit = "migx_mem_limit"; -constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy"; -constexpr const char* kGpuExternalAlloc = "migx_external_alloc"; -constexpr const char* kGpuExternalFree = "migx_external_free"; -constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache"; - -} // namespace provider_option_names -} // namespace migraphx - -MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { - MIGraphXExecutionProviderInfo info{}; - void* alloc = nullptr; - void* free = nullptr; - void* empty_cache = nullptr; +MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptions& options) { ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( - migraphx::provider_option_names::kDeviceId, - [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); + migraphx_provider_option::kDeviceId, + [this](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, device_id)); int num_devices{}; ORT_RETURN_IF_ERROR(HIP_CALL(hipGetDeviceCount(&num_devices))); ORT_RETURN_IF_NOT( - 0 <= info.device_id && info.device_id < num_devices, - "Invalid device ID: ", info.device_id, + 0 <= device_id && device_id < num_devices, + "Invalid device ID: ", device_id, ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalAlloc, - [&alloc](const std::string& value_str) -> Status { - size_t address; + migraphx_provider_option::kGpuExternalAlloc, + [this](const std::string& value_str) -> Status { + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - alloc = reinterpret_cast(address); + external_alloc = reinterpret_cast(address); return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalFree, - [&free](const std::string& value_str) -> Status { - size_t address; + migraphx_provider_option::kGpuExternalFree, + [this](const std::string& value_str) -> 
Status { + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - free = reinterpret_cast(address); + external_free = reinterpret_cast(address); return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalEmptyCache, - [&empty_cache](const std::string& value_str) -> Status { - size_t address; + migraphx_provider_option::kGpuExternalEmptyCache, + [this](const std::string& value_str) -> Status { + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - empty_cache = reinterpret_cast(address); + external_empty_cache = reinterpret_cast(address); + return Status::OK(); + }) + .AddValueParser( + migraphx_provider_option::kModelCacheDir, + [this](const std::string& value_str) -> Status { + model_cache_dir = ToPathString(value_str); return Status::OK(); }) - .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune) - .AddAssignmentToReference(migraphx::provider_option_names::kMemLimit, info.mem_limit) - .AddAssignmentToEnumReference(migraphx::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) + .AddAssignmentToReference(migraphx_provider_option::kFp16Enable, fp16_enable) + .AddAssignmentToReference(migraphx_provider_option::kBf16Enable, bf16_enable) + .AddAssignmentToReference(migraphx_provider_option::kFp8Enable, fp8_enable) + .AddAssignmentToReference(migraphx_provider_option::kInt8Enable, int8_enable) + .AddAssignmentToReference(migraphx_provider_option::kInt8UseNativeCalibTable, int8_use_native_calibration_table) + .AddAssignmentToReference(migraphx_provider_option::kInt8CalibTable, int8_calibration_table_name) + .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, exhaustive_tune) + .AddAssignmentToReference(migraphx_provider_option::kMemLimit, mem_limit) + .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) .Parse(options)); - - MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; - info.external_allocator_info = alloc_info; - - return info; } -ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXExecutionProviderInfo& info) { - const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)}, - 
{migraphx::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, - {migraphx::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, - {migraphx::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, - {migraphx::provider_option_names::kArenaExtendStrategy, - EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)}, - }; - return options; +MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const OrtMIGraphXProviderOptions& options) noexcept + : device_id{static_cast(options.device_id)}, + fp16_enable{options.migraphx_fp16_enable != 0}, + fp8_enable{options.migraphx_fp8_enable != 0}, + int8_enable{options.migraphx_int8_enable != 0}, + exhaustive_tune{options.migraphx_exhaustive_tune != 0}, + mem_limit{options.migraphx_mem_limit}, + arena_extend_strategy{options.migraphx_arena_extend_strategy} { } -ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) { - const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)}, - {migraphx::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.migraphx_arena_extend_strategy))}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, +ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { + return { + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(device_id)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(fp16_enable)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable)}, + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, 
MakeStringWithClassicLocale(external_alloc)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free)}, + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_dir)}, }; - return options; } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index a598052c5f025..414254aaa2629 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -3,70 +3,79 @@ #pragma once +#include #include #include +#include #include "core/framework/ortdevice.h" #include "core/common/hash_combine.h" #include "core/framework/arena_extend_strategy.h" #include "core/framework/provider_options.h" +#include "core/framework/provider_options_utils.h" #include "core/session/onnxruntime_c_api.h" -namespace onnxruntime { - -// Information needed to construct MIGraphX execution providers. -struct MIGraphXExecutionProviderExternalAllocatorInfo { - void* alloc{nullptr}; - void* free{nullptr}; - void* empty_cache{nullptr}; - - MIGraphXExecutionProviderExternalAllocatorInfo() { - alloc = nullptr; - free = nullptr; - empty_cache = nullptr; - } +using namespace std::literals::string_view_literals; - MIGraphXExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) { - alloc = a; - free = f; - empty_cache = e; - } +namespace onnxruntime { - bool UseExternalAllocator() const { - return (alloc != nullptr) && (free != nullptr); - } -}; +namespace migraphx_provider_option { +constexpr auto kDeviceId = "device_id"sv; +constexpr auto kFp16Enable = "migraphx_fp16_enable"sv; +constexpr auto kBf16Enable = "migraphx_bf16_enable"sv; +constexpr auto kFp8Enable = "migraphx_fp8_enable"sv; +constexpr auto kInt8Enable = "migraphx_int8_enable"sv; +constexpr auto kInt8CalibTable = "migraphx_int8_calibration_table_name"sv; +constexpr auto kInt8UseNativeCalibTable = "migraphx_int8_use_native_calibration_table"sv; +constexpr auto kExhaustiveTune = "migraphx_exhaustive_tune"sv; +constexpr auto kMemLimit = "migraphx_mem_limit"sv; +constexpr auto kArenaExtendStrategy = "migraphx_arena_extend_strategy"sv; +constexpr auto kGpuExternalAlloc = "migraphx_external_alloc"sv; +constexpr auto kGpuExternalFree = "migraphx_external_free"sv; +constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; +constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; +} // namespace migraphx_provider_option + +extern const EnumNameMapping arena_extend_strategy_mapping; // Information needed to construct trt execution providers. 
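Editor's note on the option keys above: the removed constants show the old mixed `trt_*`/`migx_*` spellings (for example `trt_fp16_enable`) being replaced by a consistent `migraphx_*` prefix, plus the new `migraphx_bf16_enable` and `migraphx_model_cache_dir` keys. A hedged sketch of how an application might pass them through the string-keyed C++ API follows; whether the `"MIGraphX"` registration name is accepted by `SessionOptions::AppendExecutionProvider` depends on the `provider_registration.cc` changes elsewhere in this patch (the `OrtMIGraphXProviderOptions` C struct remains an alternative), and `model.onnx` is a placeholder path.

```cpp
#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "migraphx-example"};
  Ort::SessionOptions so;

  // Keys follow the migraphx_provider_option namespace above; values are plain strings.
  std::unordered_map<std::string, std::string> opts{
      {"device_id", "0"},
      {"migraphx_fp16_enable", "1"},
      {"migraphx_exhaustive_tune", "0"},
      {"migraphx_model_cache_dir", "/tmp/mxr_cache"},  // empty/absent disables MXR caching
  };
  // Assumes the string-keyed registration path accepts the MIGraphX provider name.
  so.AppendExecutionProvider("MIGraphX", opts);

  Ort::Session session{env, ORT_TSTR("model.onnx"), so};  // placeholder model path
  return 0;
}
```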
struct MIGraphXExecutionProviderInfo { - std::string target_device; + std::string target_device{"gpu"}; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool bf16_enable{false}; bool fp8_enable{false}; bool int8_enable{false}; - std::string int8_calibration_table_name{""}; + std::string int8_calibration_table_name{}; bool int8_use_native_calibration_table{false}; - bool save_compiled_model{true}; - std::string save_model_file{"./compiled_model.mxr"}; - bool load_compiled_model{true}; - std::string load_model_file{"./compiled_model.mxr"}; + std::filesystem::path model_cache_dir{}; bool exhaustive_tune{false}; - size_t mem_limit{std::numeric_limits::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) - ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) + size_t mem_limit{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; OrtArenaCfg* default_memory_arena_cfg{nullptr}; - MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info{}; - static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); - static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); - static ProviderOptions ToProviderOptions(const OrtMIGraphXProviderOptions& info); + void* external_alloc{nullptr}; + void* external_free{nullptr}; + void* external_empty_cache{nullptr}; + + bool UseExternalAlloc() const { + return external_alloc != nullptr && external_free != nullptr; + } + + MIGraphXExecutionProviderInfo() = default; + + explicit MIGraphXExecutionProviderInfo(const ProviderOptions& options); + explicit MIGraphXExecutionProviderInfo(const OrtMIGraphXProviderOptions& options) noexcept; + ProviderOptions ToProviderOptions() const; }; + } // namespace onnxruntime template <> struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { - size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const { + size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const noexcept { size_t value{0xbc9f1d34}; // seed // Bits: device_id (16), arena_extend_strategy (reserved 2), boolean options (1 each) @@ -75,17 +84,21 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { (static_cast(info.fp16_enable) << 18) ^ (static_cast(info.int8_enable) << 19) ^ (static_cast(info.int8_use_native_calibration_table) << 20) ^ - (static_cast(info.save_compiled_model) << 21) ^ - (static_cast(info.load_compiled_model) << 22) ^ - (static_cast(info.exhaustive_tune) << 23); + (static_cast(info.exhaustive_tune) << 21) ^ + (static_cast(info.bf16_enable) << 22); + onnxruntime::HashCombine(data, value); + onnxruntime::HashCombine(info.target_device, value); + onnxruntime::HashCombine(info.default_memory_arena_cfg, value); + onnxruntime::HashCombine(info.int8_calibration_table_name, value); + onnxruntime::HashCombine(info.model_cache_dir, value); onnxruntime::HashCombine(info.mem_limit, value); // Memory pointers - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_alloc), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_free), value); + 
onnxruntime::HashCombine(reinterpret_cast(info.external_empty_cache), value); // The default memory arena cfg is not used in hashing right now. return value; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 9274b5696185c..cce90f3ef82be 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -3,23 +3,27 @@ #pragma once +#include +#include +#include #include -#include -#include #include -#include #include +#include +#include +#include #include "flatbuffers/idl.h" #include "core/providers/migraphx/ort_trt_int8_cal_table.fbs.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/execution_provider.h" #include "core/common/path_string.h" +#include "core/framework/murmurhash3.h" namespace fs = std::filesystem; namespace onnxruntime { -bool IsGraphInput(const GraphViewer& graph, const std::string& name) { +inline bool IsGraphInput(const GraphViewer& graph, const std::string& name) { const auto& graph_inputs = graph.GetInputs(); std::vector input_names(graph_inputs.size()); std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) { @@ -28,12 +32,12 @@ bool IsGraphInput(const GraphViewer& graph, const std::string& name) { return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); } -bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { +inline bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; return graph.GetInitializedTensor(name, initializer); } -const Node* GetInputNode(const Node& node, int arg_index) { +inline const Node* GetInputNode(const Node& node, int arg_index) { int index = 0; for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index) { if (index == arg_index) { @@ -44,7 +48,7 @@ const Node* GetInputNode(const Node& node, int arg_index) { return nullptr; } -std::size_t getNodeInputNum(const Node& node) { +inline std::size_t getNodeInputNum(const Node& node) { std::size_t node_num = 0; for (auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) { node_num++; @@ -53,14 +57,14 @@ std::size_t getNodeInputNum(const Node& node) { return node_num; } -bool isInputNode(const Node* node, const std::string& name) { +inline bool isInputNode(const Node* node, const std::string& name) { auto outputs = node->OutputDefs(); return std::any_of(outputs.begin(), outputs.end(), [&](auto out) { return (out->Name() == name); }); } -bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { +inline bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { if (node == nullptr) { return false; } @@ -113,10 +117,10 @@ bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector return true; } -bool canEvalNodeArgument(const GraphViewer& graph, - const Node* node, - std::vector indices, - std::vector& input_nodes) { +inline bool canEvalNodeArgument(const GraphViewer& graph, + const Node* node, + std::vector indices, + std::vector& input_nodes) { input_nodes.clear(); std::vector in_nodes; for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) { @@ -152,7 
+156,7 @@ bool canEvalNodeArgument(const GraphViewer& graph, return true; } -float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { +inline float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { int s = (input >> 31) & 0x01; int e = ((input & 0x7f800000) >> 23) - 127; int p = -1; @@ -184,12 +188,12 @@ float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { * Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models * */ -bool ReadDynamicRange(const std::string file_name, - const bool is_calibration_table, - std::unordered_map& dynamic_range_map) { - std::ifstream infile(file_name, std::ios::binary | std::ios::in); - if (!infile) { +inline bool ReadDynamicRange(const std::filesystem::path& filename, + const bool is_calibration_table, + std::unordered_map& dynamic_range_map) { + std::ifstream infile{filename, std::ios::binary | std::ios::in}; + if (!infile.good()) { return false; } @@ -215,7 +219,7 @@ bool ReadDynamicRange(const std::string file_name, dynamic_range_map[tensor_name] = dynamic_range; } } else { - throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name); + throw std::runtime_error("This is not a TensorRT generated calibration table " + filename.string()); } } } else { @@ -240,14 +244,111 @@ bool ReadDynamicRange(const std::string file_name, * Get cache by name * */ -std::string GetCachePath(const std::string& root, const std::string& name) { - if (root.empty()) { - return name; +inline std::filesystem::path GetCachePath(const std::filesystem::path& root, std::string_view name) { + return root.empty() ? std::filesystem::path{ToPathString(name)} : root / ToPathString(name); +} + +inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { + HashValue model_hash; + + // find the top level graph + const Graph* cur_graph = &graph_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + + const Graph& main_graph = *cur_graph; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // Use the model's file name instead of the entire path to avoid cache regeneration if a path changes + const fs::path path{main_graph.ModelPath()}; + + if (path.has_filename()) { + const auto model_name = path.filename().string(); + + LOGS_DEFAULT(INFO) << "Model name is '" << model_name << "'"; + // Ensure enough characters are hashed in case model names are too short + const size_t model_name_length = model_name.length(); + constexpr size_t hash_string_length = 500; + std::string repeat_model_name = model_name; + for (size_t i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) { + repeat_model_name += model_name; + } + hash_str(repeat_model_name); } else { - fs::path path = root; - path.append(name); - return path.string(); + LOGS_DEFAULT(INFO) << "Model path is empty"; + } + + // fingerprint current graph by hashing graph inputs + for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) { + hash_str(node_arg->Name()); + } + + // hashing outputs, inputs and inputs shapes of each node + const int number_of_ort_nodes = graph_viewer.NumberOfNodes(); + std::vector nodes_vector(number_of_ort_nodes); + std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto& index : nodes_vector) { + const auto& node = 
graph_viewer.GetNode(node_index[index]); + for (const auto* node_arg : node->OutputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + } + } + for (const auto* node_arg : node->InputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + if (node_arg->Shape() == nullptr) { + continue; + } + int dim_size = node_arg->Shape()->dim_size(); + for (int i = 0; i < dim_size; i++) { + hash_str(std::to_string(node_arg->Shape()->dim(i).dim_value())); + } + } + } + } + +#ifdef __linux__ + hash_str("LINUX"); +#elif defined(_WIN32) + hash_str("WINDOWS"); +#endif + + model_hash = hash[0] | static_cast(hash[1]) << 32; + + std::array s{}; + auto [ptr, ec] = std::to_chars(s.data(), s.data() + s.size(), model_hash, 16); + return std::string{s.data(), ptr}; +} + +inline std::string_view TrimLeft(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(0, sv.end() - std::find_if(sv.begin(), sv.end(), [fn](int ch) { + return fn(ch); + })); +} + +inline std::string_view TrimRight(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(sv.end() - std::find_if(sv.rbegin(), sv.rend(), [fn](int ch) { + return fn(ch); + }).base()); +} + +inline std::string_view Trim(std::string_view sv, int (*fn)(int) = std::isspace) { + return TrimRight(TrimLeft(sv, fn), fn); +} + +inline int ToInteger(const std::string_view sv) { + int result = 0; + if (auto [_, ec] = std::from_chars(sv.data(), sv.data() + sv.length(), result); ec == std::errc()) { + return result; } + ORT_THROW("invalid input for conversion to integer"); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_inc.h b/onnxruntime/core/providers/migraphx/migraphx_inc.h index 2b035b20f619f..49e838747892f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_inc.h +++ b/onnxruntime/core/providers/migraphx/migraphx_inc.h @@ -6,3 +6,4 @@ #include #include #include +#include diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 923a39a7d2903..626758bce36d7 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -1,28 +1,36 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License + #include +#include +#include +#include +#include + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#include +#endif #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_provider_factory_creator.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" #include "core/framework/provider_options.h" #include "core/session/onnxruntime_c_api.h" -using namespace onnxruntime; - namespace onnxruntime { void InitializeRegistry(); void DeleteRegistry(); struct MIGraphXProviderFactory : IExecutionProviderFactory { - MIGraphXProviderFactory(const MIGraphXExecutionProviderInfo& info) : info_{info} {} - ~MIGraphXProviderFactory() override {} + explicit MIGraphXProviderFactory(MIGraphXExecutionProviderInfo info) : info_{std::move(info)} {} + ~MIGraphXProviderFactory() override = default; std::unique_ptr CreateProvider() override; @@ -35,11 +43,11 @@ std::unique_ptr MIGraphXProviderFactory::CreateProvider() { } struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { - std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } - std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } @@ -61,14 +69,39 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } - std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { - return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); + std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, + void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) override { + if (alloc_fn != nullptr && free_fn != nullptr) { + AllocatorCreationInfo default_memory_info{ + [alloc_fn, free_fn, empty_cache_fn](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP, alloc_fn, free_fn, empty_cache_fn); + }, + device_id, false}; + + return CreateAllocator(default_memory_info); + } + AllocatorCreationInfo default_memory_info{ + [](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP); + }, + device_id, + true, + {default_memory_arena_cfg ? 
*default_memory_arena_cfg + : OrtArenaCfg(mem_limit, static_cast(arena_extend_strategy), + -1, -1, -1, -1L)}, + // make it stream aware + true}; + + // ROCM malloc/free is expensive so always use an arena + return CreateAllocator(default_memory_info); } } g_info; -struct MIGraphX_Provider : Provider { +struct MIGraphX_Provider final : Provider { void* GetInfo() override { return &g_info; } + virtual ~MIGraphX_Provider() = default; + std::shared_ptr CreateExecutionProviderFactory(int device_id) override { MIGraphXExecutionProviderInfo info; info.device_id = static_cast(device_id); @@ -76,72 +109,49 @@ struct MIGraphX_Provider : Provider { return std::make_shared(info); } + // Method uses ProviderOptions, and not OrtMIGraphXProviderOptions (obsolete) std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); - MIGraphXExecutionProviderInfo info; - info.device_id = static_cast(options.device_id); - info.target_device = "gpu"; - info.fp16_enable = options.migraphx_fp16_enable; - info.fp8_enable = options.migraphx_fp8_enable; - info.exhaustive_tune = options.migraphx_exhaustive_tune; - info.int8_enable = options.migraphx_int8_enable; - info.int8_calibration_table_name = ""; - if (options.migraphx_int8_calibration_table_name != nullptr) { - info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name; - } - info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0; - info.save_compiled_model = options.migraphx_save_compiled_model; - info.save_model_file = ""; - if (options.migraphx_save_model_path != nullptr) { - info.save_model_file = options.migraphx_save_model_path; + if (provider_options != nullptr) { + return std::make_shared( + MIGraphXExecutionProviderInfo{*static_cast(provider_options)}); } - info.load_compiled_model = options.migraphx_load_compiled_model; - info.load_model_file = ""; - if (options.migraphx_load_model_path != nullptr) { - info.load_model_file = options.migraphx_load_model_path; - } - info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy); - info.mem_limit = options.migraphx_mem_limit; - return std::make_shared(info); + return nullptr; } void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options); - auto& migx_options = *reinterpret_cast(provider_options); - migx_options.device_id = internal_options.device_id; - migx_options.migraphx_fp16_enable = internal_options.fp16_enable; - migx_options.migraphx_fp8_enable = internal_options.fp8_enable; - migx_options.migraphx_int8_enable = internal_options.int8_enable; - migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; - - char* dest = nullptr; - auto str_size = internal_options.int8_calibration_table_name.size(); - if (str_size == 0) { - migx_options.migraphx_int8_calibration_table_name = nullptr; + MIGraphXExecutionProviderInfo internal_options{options}; + const auto migx_options = static_cast(provider_options); + migx_options->device_id = internal_options.device_id; + migx_options->migraphx_fp16_enable = internal_options.fp16_enable; + migx_options->migraphx_fp8_enable = internal_options.fp8_enable; + migx_options->migraphx_int8_enable = internal_options.int8_enable; + migx_options->migraphx_exhaustive_tune = internal_options.exhaustive_tune; + + if (internal_options.int8_calibration_table_name.empty()) { + 
migx_options->migraphx_int8_calibration_table_name = nullptr; } else { - dest = new char[str_size + 1]; + auto str_size = internal_options.int8_calibration_table_name.size(); + auto dest = new char[str_size + 1]; #ifdef _MSC_VER strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size); #else strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); #endif dest[str_size] = '\0'; - migx_options.migraphx_int8_calibration_table_name = (const char*)dest; + migx_options->migraphx_int8_calibration_table_name = static_cast(dest); } - migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; + migx_options->migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; - migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model; - migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str(); - migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model; - migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str(); - migx_options.migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); - migx_options.migraphx_mem_limit = internal_options.mem_limit; + migx_options->migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); + migx_options->migraphx_mem_limit = internal_options.mem_limit; } ProviderOptions GetProviderOptions(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); - return onnxruntime::MIGraphXExecutionProviderInfo::ToProviderOptions(options); + return provider_options != nullptr ? MIGraphXExecutionProviderInfo{ + *static_cast(provider_options)} + .ToProviderOptions() + : ProviderOptions{}; } Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, @@ -152,19 +162,29 @@ struct MIGraphX_Provider : Provider { const OrtLogger& logger, std::unique_ptr& ep) override { ORT_UNUSED_PARAMETER(num_devices); - const ConfigOptions* config_options = &session_options.GetConfigOptions(); - - std::array configs_array = {&provider_options, config_options}; - OrtMIGraphXProviderOptions migraphx_options; - UpdateProviderOptions(&migraphx_options, provider_options); - - auto ep_factory = CreateExecutionProviderFactory(&migraphx_options); + const auto ep_factory = CreateExecutionProviderFactory(&provider_options); ep = ep_factory->CreateProvider(session_options, logger); - return Status::OK(); } void Initialize() override { +#ifdef _WIN32 + HMODULE module = nullptr; + if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + static_cast(static_cast(InitializeRegistry)), + &module) != 0) { + std::vector pathBuf; + for (;;) { + pathBuf.resize(pathBuf.size() + MAX_PATH); + if (const auto writen = GetModuleFileNameW(module, pathBuf.data(), static_cast(pathBuf.size())); writen < pathBuf.size()) { + break; + } + } + std::filesystem::path path(pathBuf.begin(), pathBuf.end()); + SetDllDirectoryW(path.parent_path().native().c_str()); + } +#endif InitializeRegistry(); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index d1c9457bafa0f..c23c9947c8d9b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -1,22 +1,23 @@ -// Copyright 2019 AMD 
AMDMIGraphX +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License -#include "core/framework/provider_options.h" -#include "onnxruntime_c_api.h" +#pragma once + +#include + +#include "core/framework/arena_extend_strategy.h" +#include "core/framework/ortdevice.h" namespace onnxruntime { class IAllocator; -class IDataTransfer; -struct IExecutionProviderFactory; -struct MIGraphXExecutionProviderInfo; -enum class ArenaExtendStrategy : int32_t; -struct MIGraphXExecutionProviderExternalAllocatorInfo; struct ProviderInfo_MIGraphX { - virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0; virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0; - virtual std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, + ArenaExtendStrategy arena_extend_strategy, void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) = 0; protected: ~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h index 02d30ad0f6fbb..db169b9e2f5a9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h @@ -6,6 +6,7 @@ #include #include "core/providers/providers.h" +#include "core/framework/provider_options.h" struct OrtMIGraphXProviderOptions; @@ -14,5 +15,6 @@ namespace onnxruntime { struct MIGraphXProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(const OrtMIGraphXProviderOptions* options); + static std::shared_ptr Create(const ProviderOptions&); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index 6e492327a73a3..0baa8a1c67c67 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -1,17 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
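Editor's note on the `ProviderInfo_MIGraphX` interface above: `CreateMIGraphXAllocator` now receives the external allocation callbacks as raw `void*` (`alloc_fn`/`free_fn`/`empty_cache_fn`) instead of the removed `MIGraphXExecutionProviderExternalAllocatorInfo` struct, and the options parser earlier in this patch recovers those pointers from decimal strings via `std::uintptr_t`. The sketch below shows only that string round trip; the concrete callback signatures the MIGraphX external allocator expects are not visible in these hunks, so `ExampleAllocCallback` is a placeholder (a real pair would presumably wrap `hipMalloc`/`hipFree`) and no typed call is made through the decoded pointer.

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Placeholder callback standing in for an application-provided allocator hook.
void ExampleAllocCallback() {}

int main() {
  // Encode: what an application would place in "migraphx_external_alloc".
  const std::string encoded =
      std::to_string(reinterpret_cast<std::uintptr_t>(&ExampleAllocCallback));

  // Decode: what the provider-options parser does before handing the pointer on.
  void* decoded = reinterpret_cast<void*>(
      static_cast<std::uintptr_t>(std::stoull(encoded)));

  std::cout << "round-tripped pointer: " << decoded << "\n";
  return 0;
}
```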
-#include
-#include "migraphx_stream_handle.h"
+#include
+#include
+
+#include "core/providers/resource.h"
+#include "core/providers/migraphx/migraphx_stream_handle.h"
+
+#define MIGRAPHX_RESOURCE_VERSION 1

 namespace onnxruntime {
-struct MIGraphXNotification : public synchronize::Notification {
-  MIGraphXNotification(Stream& s) : Notification(s) {
+enum MIGraphXResource {
+  hip_stream_t = rocm_resource_offset
+};
+
+struct MIGraphXNotification : synchronize::Notification {
+  explicit MIGraphXNotification(Stream& s) : Notification(s) {
     HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming));
   }

-  ~MIGraphXNotification() {
+  ~MIGraphXNotification() override {
     if (event_)
       HIP_CALL_THROW(hipEventDestroy(event_));
   }

@@ -21,19 +30,19 @@ struct MIGraphXNotification : public synchronize::Notification {
     HIP_CALL_THROW(hipEventRecord(event_, static_cast<hipStream_t>(GetStream().GetHandle())));
   }

-  void wait_on_device(Stream& device_stream) {
-    ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:",
-                device_stream.GetDevice().ToString());
-    // launch a wait command to the migraphx stream
-    HIP_CALL_THROW(hipStreamWaitEvent(static_cast<hipStream_t>(device_stream.GetHandle()), event_, 0));
-  };
+  void wait_on_device(Stream* device_stream) const {
+    if (device_stream != nullptr) {
+      ORT_ENFORCE(device_stream->GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream->GetDevice().ToString());
+      // launch a wait command to the migraphx stream
+      HIP_CALL_THROW(hipStreamWaitEvent(static_cast<hipStream_t>(device_stream->GetHandle()), event_, 0));
+    }
+  }

-  void wait_on_host() {
-    // CUDA_CALL_THROW(cudaStreamSynchronize(stream_));
+  void wait_on_host() const {
     HIP_CALL_THROW(hipEventSynchronize(event_));
   }

-  hipEvent_t event_;
+  hipEvent_t event_{};
 };

 MIGraphXStream::MIGraphXStream(hipStream_t stream,
@@ -41,15 +50,14 @@ MIGraphXStream::MIGraphXStream(hipStream_t stream,
                                AllocatorPtr cpu_allocator,
                                bool release_cpu_buffer_on_migraphx_stream)
     : Stream(stream, device),
-      cpu_allocator_(cpu_allocator),
+      cpu_allocator_(std::move(cpu_allocator)),
       release_cpu_buffer_on_migraphx_stream_(release_cpu_buffer_on_migraphx_stream) {
 }

 MIGraphXStream::~MIGraphXStream() {
-  ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd());
+  ORT_IGNORE_RETURN_VALUE(MIGraphXStream::CleanUpOnRunEnd());
   if (own_stream_) {
-    auto* handle = GetHandle();
-    if (handle)
+    if (auto* handle = GetHandle())
       HIP_CALL_THROW(hipStreamDestroy(static_cast<hipStream_t>(handle)));
   }
 }

@@ -87,12 +95,12 @@ struct CpuBuffersInfo {
   std::unique_ptr<void*[]> buffers;  // CPU buffer buffers[i].
   // Number of buffer points in "buffers".
-  size_t n_buffers;
+  size_t n_buffers{};
 };

 static void ReleaseCpuBufferCallback(void* raw_info) {
   std::unique_ptr<CpuBuffersInfo> info = std::make_unique<CpuBuffersInfo>();
-  info.reset(reinterpret_cast<CpuBuffersInfo*>(raw_info));
+  info.reset(static_cast<CpuBuffersInfo*>(raw_info));
   for (size_t i = 0; i < info->n_buffers; ++i) {
     info->allocator->Free(info->buffers[i]);
   }
@@ -124,29 +132,25 @@ Status MIGraphXStream::CleanUpOnRunEnd() {
 }

 void* MIGraphXStream::GetResource(int version, int id) const {
-  ORT_ENFORCE(version <= ORT_ROCM_RESOURCE_VERSION, "resource version unsupported!");
-  void* resource{};
-  switch (id) {
-    case RocmResource::hip_stream_t:
-      return reinterpret_cast<void*>(GetHandle());
-    default:
-      break;
+  ORT_ENFORCE(version <= MIGRAPHX_RESOURCE_VERSION, "resource version unsupported!");
+  if (id == hip_stream_t) {
+    return GetHandle();
   }
-  return resource;
+  return nullptr;
 }

 // CPU Stream command handles
 void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification) {
-  static_cast<MIGraphXNotification*>(&notification)->wait_on_device(*stream);
+  dynamic_cast<MIGraphXNotification*>(&notification)->wait_on_device(stream);
 }

 void WaitMIGraphXNotificationOnHost(Stream* /*stream*/, synchronize::Notification& notification) {
-  static_cast<MIGraphXNotification*>(&notification)->wait_on_host();
+  dynamic_cast<MIGraphXNotification*>(&notification)->wait_on_host();
 }

 void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
                                    const OrtDevice::DeviceType device_type,
-                                   AllocatorPtr cpu_allocator,
+                                   const AllocatorPtr& cpu_allocator,
                                    bool release_cpu_buffer_on_migraphx_stream,
                                    hipStream_t external_stream,
                                    bool use_existing_stream) {
@@ -154,19 +158,20 @@ void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_r
   stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitMIGraphXNotificationOnDevice);
   // wait migraphx notification on cpu ep
   stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitMIGraphXNotificationOnHost);
-  if (!use_existing_stream)
+  if (!use_existing_stream) {
     stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream](const OrtDevice& device) {
       HIP_CALL_THROW(hipSetDevice(device.Id()));
       hipStream_t stream = nullptr;
       HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
       return std::make_unique<MIGraphXStream>(stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream);
     });
-  else
+  } else {
     stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream, external_stream](const OrtDevice& device) {
       return std::make_unique<MIGraphXStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream);
     });
+  }
 }
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h
index 886103690c661..132ae5fc09d13 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h
@@ -2,12 +2,15 @@
 // Licensed under the MIT License.
#pragma once + +#include +#include + #include "core/framework/stream_handles.h" -#include "migraphx_inc.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/migraphx_inc.h" +#include "core/providers/migraphx/migraphx_call.h" namespace onnxruntime { -void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification); struct MIGraphXStream : Stream { MIGraphXStream(hipStream_t stream, @@ -15,7 +18,7 @@ struct MIGraphXStream : Stream { AllocatorPtr cpu_allocator, bool release_cpu_buffer_on_migraphx_stream); - ~MIGraphXStream(); + ~MIGraphXStream() override; std::unique_ptr CreateNotification(size_t /*num_consumers*/) override; @@ -27,7 +30,7 @@ struct MIGraphXStream : Stream { bool own_stream_{true}; - virtual void* GetResource(int version, int id) const; + void* GetResource(int version, int id) const override; private: std::vector deferred_cpu_buffers_; @@ -36,8 +39,8 @@ struct MIGraphXStream : Stream { }; void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, - const OrtDevice::DeviceType device_type, - AllocatorPtr cpu_allocator, + OrtDevice::DeviceType device_type, + const AllocatorPtr& cpu_allocator, bool release_cpu_buffer_on_migraphx_stream, hipStream_t external_stream, bool use_existing_stream); diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 71d51c4c2992d..a7fd83f10fe18 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -326,6 +326,12 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe const logging::Logger& logger); std::string GetEnvironmentVar(const std::string& var_name); +inline std::string GetEnvironmentVar(std::string_view var_name) { + return GetEnvironmentVar(std::string{var_name}); +} +inline std::string GetEnvironmentVar(const char* var_name) { + return GetEnvironmentVar(std::string{var_name}); +} namespace profiling { diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 031a4df59d83f..d690cf31072d2 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -790,12 +790,12 @@ Status LoadDynamicLibrary(onnxruntime::PathString library_name) { #endif #ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { - return g_host->ToUTF8String(s); +std::string ToUTF8String(std::wstring_view s) { + return g_host->ToUTF8String(std::wstring{s}); } -std::wstring ToWideString(const std::string& s) { - return g_host->ToWideString(s); +std::wstring ToWideString(std::string_view s) { + return g_host->ToWideString(std::string{s}); } #endif // _WIN32 } // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4c7b4d7b29c2f..88d84e95b406c 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1378,9 +1378,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerTypeInfo, _In_ cons return GetNodeDefTypeInfoHelper(sess, get_overridable_initializers_fn, index, out); } -char* onnxruntime::StrDup(const std::string& str, OrtAllocator* allocator) { - char* output_string = reinterpret_cast(allocator->Alloc(allocator, str.size() + 1)); - memcpy(output_string, str.c_str(), 
str.size()); +char* onnxruntime::StrDup(std::string_view str, OrtAllocator* allocator) { + char* output_string = static_cast(allocator->Alloc(allocator, str.size() + 1)); + memcpy(output_string, str.data(), str.size()); + output_string[str.size()] = '\0'; + return output_string; +} + +wchar_t* onnxruntime::StrDup(std::wstring_view str, OrtAllocator* allocator) { + auto* output_string = static_cast(allocator->Alloc(allocator, (str.size() + 1) * sizeof(wchar_t))); + memcpy(output_string, str.data(), str.size() * sizeof(wchar_t)); output_string[str.size()] = '\0'; return output_string; } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 01b70db6d940e..ee59ff2ab4932 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2108,8 +2108,13 @@ std::shared_ptr NvProviderFactoryCreator::Create( return nullptr; } -std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { - return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options); +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const ProviderOptions& provider_options) { + return s_library_migraphx.Get().CreateExecutionProviderFactory(&provider_options); +} + +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* options) { + const auto provider_options{s_library_migraphx.Get().GetProviderOptions(options)}; + return s_library_migraphx.Get().CreateExecutionProviderFactory(&provider_options); } // Adapter to convert the legacy OrtOpenVINOProviderOptions to ProviderOptions diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 18a463ef69943..48d52ae3cf428 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -101,6 +101,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, VitisAI, CoreML, NvTensorRtRtx, // TensorRt EP for RTX GPUs. 
+ MIGraphX }; struct EpToAppend { @@ -109,7 +110,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, const char* canonical_name = nullptr; }; - static std::array supported_eps = { + static std::array supported_eps = { EpToAppend{EpID::DML, "DML", kDmlExecutionProvider}, EpToAppend{EpID::QNN, "QNN", kQnnExecutionProvider}, EpToAppend{EpID::OpenVINO, "OpenVINO", kOpenVINOExecutionProvider}, @@ -121,7 +122,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, EpToAppend{EpID::JS, "JS", kJsExecutionProvider}, EpToAppend{EpID::VitisAI, "VitisAI", kVitisAIExecutionProvider}, EpToAppend{EpID::CoreML, "CoreML", kCoreMLExecutionProvider}, - EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}}; + EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}, + EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}}; ProviderOptions provider_options; OrtStatus* status = ParseProviderOptions(provider_options_keys, @@ -279,6 +281,14 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); +#endif + break; + } + case EpID::MIGraphX: { +#if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE) + options->provider_factories.push_back(MIGraphXProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); #endif break; } diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 590e1ef3cdbdb..1934e0eda7956 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -196,9 +196,9 @@ void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_CUDA().cudaMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetCudaToHostMemCpyFunction() { +const std::unordered_map* GetCudaToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0}, CudaToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device.Id()}, CudaToCpuMemCpy}, }; return ↦ @@ -246,6 +246,7 @@ std::unique_ptr GetGPUDataTransfer() { #endif #ifdef USE_MIGRAPHX + void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_MIGraphX().MIGraphXMemcpy_HostToDevice(dst, src, num_bytes); } @@ -256,7 +257,7 @@ void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { const std::unordered_map* GetMIGraphXToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, MIGraphXToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, MIGraphXToCpuMemCpy}, }; return ↦ @@ -270,7 +271,11 @@ AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id) { if (id_to_allocator_map->find(id) == id_to_allocator_map->end()) { // TODO: Expose knobs so that users can set fields associated with OrtArenaCfg so that we can pass it to the following method - id_to_allocator_map->insert({id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator(id, gpu_mem_limit, arena_extend_strategy, migx_external_allocator_info, nullptr)}); + id_to_allocator_map->insert( + 
{id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator( + id, gpu_mem_limit, arena_extend_strategy, + migraphx::external::alloc_fn, migraphx::external::free_fn, migraphx::external::empty_cache_fn, + nullptr)}); } return (*id_to_allocator_map)[id]; @@ -374,9 +379,9 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { D3D12_RESOURCE_STATE_UNORDERED_ACCESS); } -const std::unordered_map* GetDmlToHostMemCpyFunction() { +const std::unordered_map* GetDmlToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, 0}, DmlToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device.Id()}, DmlToCpuMemCpy}, }; return ↦ @@ -444,9 +449,9 @@ void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_ROCM().rocmMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetRocmToHostMemCpyFunction() { +const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, RocmToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, RocmToCpuMemCpy}, }; return ↦ diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h index 7b65c0aae45c1..eba783d826212 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h @@ -74,7 +74,7 @@ void CpuToCudaMemCpy(void* dst, const void* src, size_t num_bytes); void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetCudaToHostMemCpyFunction(); +const std::unordered_map* GetCudaToHostMemCpyFunction(const OrtDevice&); bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id); @@ -92,7 +92,7 @@ void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes); void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetDmlToHostMemCpyFunction(); +const std::unordered_map* GetDmlToHostMemCpyFunction(const OrtDevice&); #endif @@ -102,7 +102,7 @@ void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes); void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetMIGraphXToHostMemCpyFunction(); +const std::unordered_map* GetMIGraphXToHostMemCpyFunction(const OrtDevice&); AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id); @@ -132,7 +132,7 @@ void CpuToRocmMemCpy(void* dst, const void* src, size_t num_bytes); void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetRocmToHostMemCpyFunction(); +const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice&); #endif diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 7234543eb14de..1fe7ab0884f9c 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -421,21 +421,39 @@ void addOrtValueMethods(pybind11::module& m) { // Converts Tensor into a numpy array .def("numpy", [](const OrtValue* ml_value) -> py::object { ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects"); - + [[maybe_unused]] const auto& device = ml_value->Get().Location().device; 
+#ifdef _MSC_VER +// The switch statement may only contain the 'default' label. In such a case, the MSVC compiler +// will warn about it, and since the warnings are treated as errors, the compilation will break. +// Below pragmas turn off warning generation for this switch only. +#pragma warning(push) +#pragma warning(disable : 4065) +#endif + switch (device.Vendor()) { #ifdef USE_CUDA - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction()); -#elif USE_ROCM - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction()); -#elif USE_CANN - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); -#elif USE_DML - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction()); -#elif USE_MIGRAPHX - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction()); -#else - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr); + case OrtDevice::VendorIds::NVIDIA: + return GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction(device)); +#endif +#ifdef USE_CANN + case OrtDevice::VendorIds::HUAWEI: + return GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); #endif - return obj; }) + +#ifdef USE_DML + case OrtDevice::VendorIds::MICROSOFT: + return GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction(device)); +#endif +#ifdef USE_MIGRAPHX + case OrtDevice::VendorIds::AMD: + return GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction(device)); +#endif + default: + return GetPyObjFromTensor(*ml_value, nullptr, nullptr); + } +#ifdef _MSC_VER +#pragma warning(pop) +#endif + }) #if defined(ENABLE_DLPACK) .def("to_dlpack", [](OrtValue* ort_value) -> py::object { return py::reinterpret_steal(ToDlpack(*ort_value)); }, "Returns a DLPack representing the tensor. This method does not copy the pointer shape, " diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 03ad0185d1394..24554560b4dde 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -979,135 +979,10 @@ static std::shared_ptr CreateExecutionProviderFactory #endif } else if (type == kMIGraphXExecutionProvider) { #if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE) - std::string calibration_table; - std::string save_model_path; - std::string load_model_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { - OrtMIGraphXProviderOptions params{ - 0, - 0, - 0, - 0, - 0, - nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", - 1, - SIZE_MAX, - 0}; - for (auto option : it->second) { - if (option.first == "device_id") { - if (!option.second.empty()) { - params.device_id = std::stoi(option.second); - } else { - ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n"); - } - } else if (option.first == "migraphx_fp16_enable") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be" - " 'True' or 'False'. 
Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_fp8_enable") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp8_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp8_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_int8_enable") { - if (option.second == "True" || option.second == "true") { - params.migraphx_int8_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_int8_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_int8_calibration_table_name") { - if (!option.second.empty()) { - calibration_table = option.second; - params.migraphx_int8_calibration_table_name = calibration_table.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a " - "file name i.e. 'cal_table'.\n"); - } - } else if (option.first == "migraphx_use_native_calibration_table") { - if (option.second == "True" || option.second == "true") { - params.migraphx_use_native_calibration_table = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_use_native_calibration_table = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_save_compiled_model") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_save_model_path") { - if (!option.second.empty()) { - save_model_path = option.second; - params.migraphx_save_model_path = save_model_path.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_name' should be a " - "file name i.e. 'compiled_model.mxr'.\n"); - } - } else if (option.first == "migraphx_load_compiled_model") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_load_model_path") { - if (!option.second.empty()) { - load_model_path = option.second; - params.migraphx_load_model_path = load_model_path.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a " - "file name i.e. 
'compiled_model.mxr'.\n"); - } - } else if (option.first == "migraphx_exhaustive_tune") { - if (option.second == "True" || option.second == "true") { - params.migraphx_exhaustive_tune = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_exhaustive_tune = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else { - ORT_THROW("Invalid MIGraphX EP option: ", option.first); - } - } if (std::shared_ptr migraphx_provider_factory = - onnxruntime::MIGraphXProviderFactoryCreator::Create(¶ms)) { + onnxruntime::MIGraphXProviderFactoryCreator::Create(it->second)) { return migraphx_provider_factory; } } else { @@ -1917,7 +1792,6 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra vendor = OrtDevice::VendorIds::HUAWEI; #endif } - return OrtDevice(type, mem_type, vendor, device_id); }), R"pbdoc(Constructor with vendor_id defaulted to 0 for backward compatibility.)pbdoc") diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc index 4b9e012764885..cccdb9d23900a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.cc +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -47,7 +47,11 @@ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExten #endif #ifdef USE_MIGRAPHX -onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info{}; +namespace migraphx::external { +void* alloc_fn{nullptr}; +void* free_fn{nullptr}; +void* empty_cache_fn{nullptr}; +} // namespace migraphx::external #endif #if defined(ENABLE_DLPACK) diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 706c151936192..b4a33e798f942 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -40,7 +40,7 @@ struct OrtStatus { #define BACKEND_PROC "CPU" #endif -#if USE_DNNL +#ifdef USE_DNNL #define BACKEND_DNNL "-DNNL" #else #define BACKEND_DNNL "" @@ -226,9 +226,14 @@ extern onnxruntime::ArenaExtendStrategy arena_extend_strategy; namespace onnxruntime { ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX(); ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX(); -namespace python { -extern onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info; -} // namespace python +namespace python::migraphx::external { +extern void* alloc_fn; +extern void* free_fn; +extern void* empty_cache_fn; +inline bool UseExternalAllocator() { + return alloc_fn != nullptr && free_fn != nullptr; +} +} // namespace python::migraphx::external } // namespace onnxruntime #endif diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 2e4aa3923b649..bae7a14908916 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -80,21 +80,7 @@ std::unique_ptr TensorrtExecutionProviderWithOptions(const O std::unique_ptr DefaultMIGraphXExecutionProvider() { #ifdef USE_MIGRAPHX - OrtMIGraphXProviderOptions params{ - 0, - 0, - 0, - 0, - 0, - nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", - 1, - SIZE_MAX, - 0}; - return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider(); + return MIGraphXProviderFactoryCreator::Create(ProviderOptions{})->CreateProvider(); #else 
return nullptr; #endif @@ -102,7 +88,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { std::unique_ptr MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params) { #ifdef USE_MIGRAPHX - if (auto factory = MIGraphXProviderFactoryCreator::Create(params)) + if (const auto factory = MIGraphXProviderFactoryCreator::Create(params); factory != nullptr) return factory->CreateProvider(); #else ORT_UNUSED_PARAMETER(params); diff --git a/setup.py b/setup.py index 5ab1ac5b840d4..6bfb53329f319 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ def parse_arg_remove_string(argv, arg_name_equal): is_cuda_version_12 = cuda_version.startswith("12.") elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"): is_migraphx = True + package_name = "onnxruntime-migraphx" elif parse_arg_remove_boolean(sys.argv, "--use_openvino"): is_openvino = True package_name = "onnxruntime-openvino" @@ -90,8 +91,6 @@ def parse_arg_remove_string(argv, arg_name_equal): is_qnn = True package_name = "onnxruntime-qnn" qnn_version = parse_arg_remove_string(sys.argv, "--qnn_version=") -elif is_migraphx: - package_name = "onnxruntime-migraphx" if not nightly_build else "ort-migraphx-nightly" # PEP 513 defined manylinux1_x86_64 and manylinux1_i686 # PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686 @@ -283,7 +282,6 @@ def run(self): self._rewrite_ld_preload_tensorrt(to_preload_tensorrt) self._rewrite_ld_preload_tensorrt(to_preload_nv_tensorrt_rtx) self._rewrite_ld_preload(to_preload_cann) - else: pass @@ -412,6 +410,7 @@ def finalize_options(self): libs.extend(["onnxruntime_providers_nv_tensorrt_rtx.dll"]) libs.extend(["onnxruntime_providers_openvino.dll"]) libs.extend(["onnxruntime_providers_cuda.dll"]) + libs.extend(["onnxruntime_providers_migraphx.dll"]) libs.extend(["onnxruntime_providers_vitisai.dll"]) libs.extend(["onnxruntime_providers_qnn.dll"]) # DirectML Libs @@ -435,6 +434,26 @@ def finalize_options(self): libs.extend(qnn_deps) if nightly_build: libs.extend(["onnxruntime_pywrapper.dll"]) + migraphx_deps = [ + "amd_comgr0602.dll", + "amd_comgr0604.dll", + "amd_comgr0700.dll", + "hiprtc0602.dll", + "hiprtc0604.dll", + "hiprtc0700.dll", + "hiprtc-builtins0602.dll", + "hiprtc-builtins0604.dll", + "hiprtc-builtins0700.dll", + "migraphx-hiprtc-driver.exe", + "migraphx.dll", + "migraphx_c.dll", + "migraphx_cpu.dll", + "migraphx_device.dll", + "migraphx_gpu.dll", + "migraphx_onnx.dll", + "migraphx_tf.dll", + ] + libs.extend(migraphx_deps) if is_manylinux: if is_openvino: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 561a76be5fa89..0d51f66df33aa 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -723,8 +723,6 @@ def generate_build_tree( cmake_args += ["-Donnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD=ON"] if args.use_migraphx: cmake_args.append("-Donnxruntime_MIGRAPHX_HOME=" + migraphx_home) - cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) - cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) if args.use_tensorrt: cmake_args.append("-Donnxruntime_TENSORRT_HOME=" + tensorrt_home) @@ -1994,6 +1992,7 @@ def build_nuget_package( use_winml, use_qnn, use_dml, + use_migraphx, enable_training_apis, msbuild_extra_options, ): @@ -2031,6 +2030,9 @@ def build_nuget_package( elif use_tensorrt: execution_provider = "/p:ExecutionProvider=tensorrt" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" + elif use_migraphx: + execution_provider = "/p:ExecutionProvider=migraphx" + package_name = 
"/p:OrtPackageId=Microsoft.ML.OnnxRuntime.MIGraphX" elif use_dnnl: execution_provider = "/p:ExecutionProvider=dnnl" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DNNL" @@ -2622,6 +2624,7 @@ def main(): getattr(args, "use_winml", False), args.use_qnn, getattr(args, "use_dml", False), + args.use_migraphx, args.enable_training_apis, normalize_arg_list(args.msbuild_extra_options), ) diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 211cb7a2a8a75..ead240a7cef1b 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -22,6 +22,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "linux": pkg_name += "-linux-" pkg_name += cpu_arch @@ -31,6 +33,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "osx": pkg_name = "onnxruntime-osx-" + cpu_arch return pkg_name @@ -44,7 +48,11 @@ def get_package_name(os, cpu_arch, ep, is_training_package): def is_this_file_needed(ep, filename, package_name): if package_name == "Microsoft.ML.OnnxRuntime.Gpu": return False - return (ep != "cuda" or "cuda" in filename) and (ep != "tensorrt" or "cuda" not in filename) + return ( + (ep != "cuda" or "cuda" in filename) + and (ep != "tensorrt" or "cuda" not in filename) + and (ep != "migraphx" or "migraphx" not in filename) + ) # nuget_artifacts_dir: the directory with uncompressed C API tarball/zip files @@ -138,7 +146,7 @@ def parse_arguments(): required=False, default="None", type=str, - choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"], + choices=["cuda", "dnnl", "openvino", "migraphx", "tensorrt", "snpe", "qnn", "None"], help="The selected execution provider for this build.", ) parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.") @@ -182,6 +190,8 @@ def generate_description(line_list, package_name): description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA." elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name: description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA." + elif "Microsoft.ML.OnnxRuntime.MIGraphX" in package_name: + description = "This package contains native shared library artifacts for ONNX Runtime with MIGraphX." elif "Intel.ML.OnnxRuntime" in package_name: description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO." 
elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package @@ -359,6 +369,7 @@ def generate_files(line_list, args): is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning" is_snpe_package = args.package_name == "Microsoft.ML.OnnxRuntime.Snpe" is_qnn_package = args.package_name == "Microsoft.ML.OnnxRuntime.QNN" + is_migraphx_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX" is_training_package = args.package_name in [ "Microsoft.ML.OnnxRuntime.Training", "Microsoft.ML.OnnxRuntime.Training.Gpu", @@ -384,6 +395,24 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", "qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll", + "migraphx_ep_shared_lib": "onnxruntime_providers_migraphx.dll", + "amd_comgr0602": "amd_comgr0602.dll", + "amd_comgr0604": "amd_comgr0604.dll", + "amd_comgr0700": "amd_comgr0700.dll", + "hiprtc0602": "hiprtc0602.dll", + "hiprtc0604": "hiprtc0604.dll", + "hiprtc0700": "hiprtc0700.dll", + "hiprtc-builtins0602": "hiprtc-builtins0602.dll", + "hiprtc-builtins0604": "hiprtc-builtins0604.dll", + "hiprtc-builtins0700": "hiprtc-builtins0700.dll", + "migraphx-hiprtc-driver": "migraphx-hiprtc-driver.exe", + "migraphx": "migraphx.dll", + "migraphx_c": "migraphx_c.dll", + "migraphx_cpu": "migraphx_cpu.dll", + "migraphx_device": "migraphx_device.dll", + "migraphx_gpu": "migraphx_gpu.dll", + "migraphx_onnx": "migraphx_onnx.dll", + "migraphx_tf": "migraphx_tf", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -402,6 +431,7 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "libonnxruntime_providers_openvino.so", "cuda_ep_shared_lib": "libonnxruntime_providers_cuda.so", "rocm_ep_shared_lib": "libonnxruntime_providers_rocm.so", + "migraphx_ep_shared_lib": "libonnxruntime_providers_migraphx.so", "onnxruntime_perf_test": "onnxruntime_perf_test", "onnx_test_runner": "onnx_test_runner", } @@ -421,7 +451,7 @@ def generate_files(line_list, args): include_dir = f"{build_dir}\\native\\include" # Sub.Gpu packages do not include the onnxruntime headers - if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu": + if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu" and args.package_name != "Microsoft.ML.OnnxRuntime.MIGraphX": files_list.append( "' ) + if args.execution_provider == "migraphx": + files_list.append( + "' + ) + files_list.append( + "' + ) + + if is_windows_build: + native_build_path = Path(args.native_build_path) + + def _files_list_append(key: str): + path = native_build_path / nuget_dependencies[key] + if path.exists(): + files_list.append( + "' + ) + + _files_list_append("amd_comgr0602") + _files_list_append("amd_comgr0604") + _files_list_append("amd_comgr0700") + _files_list_append("hiprtc0602") + _files_list_append("hiprtc0604") + _files_list_append("hiprtc0700") + _files_list_append("hiprtc-builtins0602") + _files_list_append("hiprtc-builtins0604") + _files_list_append("hiprtc-builtins0700") + _files_list_append("migraphx-hiprtc-driver") + _files_list_append("migraphx") + _files_list_append("migraphx_c") + _files_list_append("migraphx_cpu") + _files_list_append("migraphx_device") + _files_list_append("migraphx_gpu") + _files_list_append("migraphx_onnx") + _files_list_append("migraphx_tf") + if is_dml_package: files_list.append( "
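
A minimal usage sketch (illustrative only, not part of the patch series above): with the MIGraphX entry added to SessionOptionsAppendExecutionProvider and the ProviderOptions-based factory creator, an application can request the EP by name through the generic provider-append API. The model path, device id, and option values below are placeholders; the accepted option keys and values are whatever the MIGraphX EP's option parser handles (e.g. device_id, migraphx_fp16_enable, migraphx_exhaustive_tune), so check them against the EP documentation for the build in use.

// Sketch: request the MIGraphX EP by name via the generic provider-append API.
// Assumes an onnxruntime build with the MIGraphX EP enabled; option keys/values
// are passed through as strings and parsed by the EP itself.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "migraphx_example"};
  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider("MIGraphX",
                                          {{"device_id", "0"},
                                           {"migraphx_fp16_enable", "true"}});
  Ort::Session session{env, ORT_TSTR("model.onnx"), session_options};  // placeholder model path
  return 0;
}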