diff --git a/.gitignore b/.gitignore index 4d0a1205b7c19..b25763334f227 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ # build, distribute, and bins (+ python proto bindings) +build.*/ build build_*/ .build_debug/* diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index b0941b4d0c922..4a3ca3dcd7741 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -29,6 +29,7 @@ include(CheckLanguage) include(CMakeDependentOption) include(FetchContent) include(CheckFunctionExists) +include(CheckSymbolExists) include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables # TODO: update this once all system adapt c++20 diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 495ff093326ad..8cb5dcf95155a 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -2,21 +2,11 @@ # Licensed under the MIT License. add_definitions(-DUSE_MIGRAPHX=1) - set(BUILD_LIBRARY_ONLY 1) - add_definitions("-DONNX_ML=1") - add_definitions("-DONNX_NAMESPACE=onnx") - include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR}) - set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME}) - include_directories(${onnx_SOURCE_DIR}) + include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR} ${onnx_SOURCE_DIR}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) + if (CMAKE_COMPILER_IS_GNUCC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") endif() - set(CXX_VERSION_DEFINED TRUE) - set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") - endif() # Add search paths for default rocm installation list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH}) @@ -33,8 +23,6 @@ find_package(hip REQUIRED) find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME}) - set(migraphx_libs migraphx::c hip::host) - file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" @@ -42,14 +30,14 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) + onnxruntime_add_shared_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) + add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE migraphx::c hip::host ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} 
${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/migraphx/onnxruntime)
  set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX)
  set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime")
-  target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1)
+  target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1 ONNX_ML=1 ONNX_NAMESPACE=onnx)
  if(MSVC)
    set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def)
    target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32)
@@ -62,6 +50,15 @@
    target_link_libraries(onnxruntime_providers_migraphx PRIVATE stdc++fs)
  endif()

+  set(CMAKE_REQUIRED_LIBRARIES migraphx::c)
+
+  check_symbol_exists(migraphx_onnx_options_set_external_data_path
+    "migraphx/migraphx.h" HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH)
+
+  if(HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH)
+    target_compile_definitions(onnxruntime_providers_migraphx PRIVATE HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH=1)
+  endif()
+
  if (onnxruntime_ENABLE_TRAINING_OPS)
    onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_training)
    target_link_libraries(onnxruntime_providers_migraphx PRIVATE onnxruntime_training)
@@ -71,15 +68,39 @@
  endif()

  if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
-    install(TARGETS onnxruntime_providers_migraphx
-            ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
-            LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR}
-            RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
-    )
-  else()
-    install(TARGETS onnxruntime_providers_migraphx
-            ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
-            LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
-            RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
-    )
+    foreach(file migraphx-hiprtc-driver.exe migraphx.dll migraphx_c.dll migraphx_cpu.dll migraphx_device.dll migraphx_gpu.dll migraphx_onnx.dll migraphx_tf.dll)
+      set(_source "${AMD_MIGRAPHX_HOME}/bin/${file}")
+      if(EXISTS "${_source}")
+        add_custom_command(TARGET onnxruntime_providers_migraphx
+          POST_BUILD
+          COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $<TARGET_FILE_DIR:onnxruntime_providers_migraphx>)
+        set(_target "$<TARGET_FILE_DIR:onnxruntime_providers_migraphx>/${file}")
+        list(APPEND _migraphx_targets ${_target})
+      endif()
+    endforeach()
+    set(MIGRAPHX_LIB_FILES ${_migraphx_targets} CACHE INTERNAL "" FORCE)
+    install(FILES ${_migraphx_targets}
+            DESTINATION ${CMAKE_INSTALL_BINDIR})
+    get_property(_amdhip64_location TARGET hip::amdhip64 PROPERTY IMPORTED_LOCATION_RELEASE)
+    cmake_path(GET _amdhip64_location PARENT_PATH _hipsdk_path)
+    foreach(file amd_comgr0602.dll amd_comgr0604.dll amd_comgr0700.dll hiprtc0602.dll hiprtc0604.dll hiprtc0700.dll hiprtc-builtins0602.dll hiprtc-builtins0604.dll hiprtc-builtins0700.dll)
+      set(_source "${_hipsdk_path}/${file}")
+      if(EXISTS "${_source}")
+        add_custom_command(TARGET onnxruntime_providers_migraphx
+          POST_BUILD
+          COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $<TARGET_FILE_DIR:onnxruntime_providers_migraphx>)
+        set(_target "$<TARGET_FILE_DIR:onnxruntime_providers_migraphx>/${file}")
+        list(APPEND _hipsdk_targets ${_target})
+      endif()
+    endforeach()
+    set(HIPSDK_LIB_FILES ${_hipsdk_targets} CACHE INTERNAL "" FORCE)
+    install(FILES ${_hipsdk_targets}
+            DESTINATION ${CMAKE_INSTALL_BINDIR})
  endif()
+
+  install(TARGETS onnxruntime_providers_migraphx
+          EXPORT onnxruntime_providers_migraphxTargets
+          ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
+          LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+          RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
+          FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
diff --git
a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index c5c85dff96411..ae976abe62fd8 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -740,6 +740,21 @@ if (onnxruntime_USE_OPENVINO) ) endif() +if (onnxruntime_USE_MIGRAPHX) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${MIGRAPHX_LIB_FILES} + $/onnxruntime/capi/) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${HIPSDK_LIB_FILES} + $/onnxruntime/capi/) + endif() +endif() + if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 36ba8db9bdc75..a1e9d06b133fd 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -607,7 +607,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_COREML) @@ -688,9 +687,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*) - list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_NNAPI_BUILTIN) diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index adfd341451aed..820d140ccaabc 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -294,12 +294,26 @@ inline std::string ToUTF8String(const std::string& s) { return s; } /** * Convert a wide character string to a UTF-8 string */ -std::string ToUTF8String(const std::wstring& s); - -std::wstring ToWideString(const std::string& s); +std::string ToUTF8String(std::wstring_view s); +inline std::string ToUTF8String(const wchar_t* s) { + return ToUTF8String(std::wstring_view{s}); +} +inline std::string ToUTF8String(const std::wstring& s) { + return ToUTF8String(std::wstring_view{s}); +} +std::wstring ToWideString(std::string_view s); +inline std::wstring ToWideString(const char* s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::string& s) { + return ToWideString(std::string_view{s}); +} inline std::wstring ToWideString(const std::wstring& s) { return s; } +inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; } #else inline std::string ToWideString(const std::string& s) { return s; } +inline std::string ToWideString(const char* s) { return s; } +inline std::string ToWideString(std::string_view s) { return std::string{s}; } #endif constexpr size_t kMaxStrLen = 4096; diff --git a/include/onnxruntime/core/common/string_helper.h b/include/onnxruntime/core/common/string_helper.h index 1304303132d5a..c0b331cb8e9a8 100644 --- a/include/onnxruntime/core/common/string_helper.h +++ b/include/onnxruntime/core/common/string_helper.h @@ -7,5 +7,9 @@ // forward declaration struct OrtAllocator; namespace onnxruntime { -char* StrDup(const std::string& str, 
OrtAllocator* allocator); +char* StrDup(std::string_view str, OrtAllocator* allocator); +inline char* StrDup(const std::string& str, OrtAllocator* allocator) { + return StrDup(std::string_view{str}, allocator); +} +wchar_t* StrDup(std::wstring_view str, OrtAllocator* allocator); } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/provider_options_utils.h b/include/onnxruntime/core/framework/provider_options_utils.h index 5967fb91523d0..badb7320ea49e 100644 --- a/include/onnxruntime/core/framework/provider_options_utils.h +++ b/include/onnxruntime/core/framework/provider_options_utils.h @@ -83,12 +83,24 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddValueParser( const std::string& name, ValueParserType value_parser) { + return AddValueParser(std::string_view{name}, value_parser); + } + + template + ProviderOptionsParser& AddValueParser( + std::string_view name, ValueParserType value_parser) { ORT_ENFORCE( value_parsers_.emplace(name, ValueParser{value_parser}).second, "Provider option \"", name, "\" already has a value parser."); return *this; } + template + ProviderOptionsParser& AddValueParser( + const char* name, ValueParserType value_parser) { + return AddValueParser(std::string_view{name}, value_parser); + } + /** * Adds a parser for a particular provider option value which converts a * value to the right type and assigns it to the given reference. @@ -104,13 +116,25 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddAssignmentToReference( const std::string& name, ValueType& dest) { + return AddAssignmentToReference(std::string_view{name}, dest); + } + + template + ProviderOptionsParser& AddAssignmentToReference( + std::string_view name, ValueType& dest) { return AddValueParser( name, - [&dest](const std::string& value_str) -> Status { + [&dest](std::string_view value_str) -> Status { return ParseStringWithClassicLocale(value_str, dest); }); } + template + ProviderOptionsParser& AddAssignmentToReference( + const char* name, ValueType& dest) { + return AddAssignmentToReference(std::string_view{name}, dest); + } + /** * Adds a parser for a particular provider option value which maps an * enumeration name to a value and assigns it to the given reference. @@ -128,6 +152,12 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddAssignmentToEnumReference( const std::string& name, const EnumNameMapping& mapping, EnumType& dest) { + return AddAssignmentToEnumReference(std::string_view{name}, mapping, dest); + } + + template + ProviderOptionsParser& AddAssignmentToEnumReference( + std::string_view name, const EnumNameMapping& mapping, EnumType& dest) { return AddValueParser( name, [&mapping, &dest](const std::string& value_str) -> Status { @@ -135,6 +165,12 @@ class ProviderOptionsParser { }); } + template + ProviderOptionsParser& AddAssignmentToEnumReference( + const char* name, const EnumNameMapping& mapping, EnumType& dest) { + return AddAssignmentToEnumReference(std::string_view{name}, mapping, dest); + } + /** * Parses the given provider options. */ diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2899a219bdda0..6eb15280a4aa4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -754,13 +754,13 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_fp16_enable; // MIGraphX FP16 precision. 
Default 0 = false, nonzero = true
  int migraphx_fp8_enable;   // MIGraphX FP8 precision. Default 0 = false, nonzero = true
  int migraphx_int8_enable;  // MIGraphX INT8 precision. Default 0 = false, nonzero = true
-  int migraphx_use_native_calibration_table;  // MIGraphx INT8 cal table. Default 0 = false, noznero = true
+  int migraphx_use_native_calibration_table;  // MIGraphx INT8 cal table. Default 0 = false, nonzero = true
  const char* migraphx_int8_calibration_table_name;  // MIGraphx INT8 calibration table name
-  int migraphx_save_compiled_model;  // migraphx save compiled model. Default 0 = false, noznero = true
+  int migraphx_save_compiled_model;  // migraphx save compiled model. Default 0 = false, nonzero = true
  const char* migraphx_save_model_path;  // migraphx model path name
-  int migraphx_load_compiled_model;  // migraphx int8 cal table. Default 0 = false, noznero = true
+  int migraphx_load_compiled_model;  // migraphx int8 cal table. Default 0 = false, nonzero = true
  const char* migraphx_load_model_path;  // migraphx model path name
-  bool migraphx_exhaustive_tune;  // migraphx tuned compile Default = false
+  bool migraphx_exhaustive_tune;  // MIGraphX tuned compile. Default = false, nonzero = true

  /** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t)
   *   Defaults to SIZE_MAX.
@@ -776,6 +776,7 @@ typedef struct OrtMIGraphXProviderOptions {
   */
  int migraphx_arena_extend_strategy;

+  // This is the legacy struct; do not add new fields here.
} OrtMIGraphXProviderOptions;

/** \brief OpenVINO Provider Options
diff --git a/onnxruntime/core/common/helper.cc b/onnxruntime/core/common/helper.cc
index 6a52db73df106..07cd1672b27c1 100644
--- a/onnxruntime/core/common/helper.cc
+++ b/onnxruntime/core/common/helper.cc
@@ -18,7 +18,7 @@ namespace onnxruntime {
 #ifdef _WIN32
-std::string ToUTF8String(const std::wstring& s) {
+std::string ToUTF8String(std::wstring_view s) {
   if (s.size() >= static_cast<size_t>(std::numeric_limits<int>::max()))
     ORT_THROW("length overflow");
@@ -33,7 +33,7 @@ std::string ToUTF8String(const std::wstring& s) {
   return ret;
 }

-std::wstring ToWideString(const std::string& s) {
+std::wstring ToWideString(std::string_view s) {
   if (s.size() >= static_cast<size_t>(std::numeric_limits<int>::max()))
     ORT_THROW("length overflow");
diff --git a/onnxruntime/core/common/path_string.h b/onnxruntime/core/common/path_string.h
index 6cfb327cce08a..4ca326d76a37d 100644
--- a/onnxruntime/core/common/path_string.h
+++ b/onnxruntime/core/common/path_string.h
@@ -40,6 +40,12 @@ inline PathString ToPathString(const PathString& s) {
 static_assert(std::is_same::value, "PathString is not std::wstring!");

+inline PathString ToPathString(std::string_view s) {
+  return ToWideString(s);
+}
+inline PathString ToPathString(const char* s) {
+  return ToWideString(s);
+}
 inline PathString ToPathString(const std::string& s) {
   return ToWideString(s);
 }
@@ -56,6 +62,14 @@ inline std::string PathToUTF8String(const PathString& s) {
 static_assert(std::is_same::value, "PathString is not std::string!");

+inline PathString ToPathString(const char* s) {
+  return s;
+}
+
+inline PathString ToPathString(std::string_view s) {
+  return PathString{s};
+}
+
 inline PathChar ToLowerPathChar(PathChar c) {
   return std::tolower(c);
 }
diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
index 5918716b3e77f..a4eb8efd2afea 100644
--- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
@@ -3,7 +3,7 @@
 #pragma once

-#include "migraphx_inc.h"
+#include "core/providers/migraphx/migraphx_inc.h"
 #include "core/framework/data_transfer.h"

 namespace onnxruntime {
diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc
index 1cac133ab0c2c..911a1a7fd18b9 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc
@@ -2,12 +2,11 @@
 // Licensed under the MIT License.

 #include "core/providers/shared_library/provider_api.h"
-#include "migraphx_call.h"
-#include "migraphx_allocator.h"
+#include "core/providers/migraphx/migraphx_call.h"
+#include "core/providers/migraphx/migraphx_allocator.h"
 #include "core/common/status.h"
 #include "core/framework/float16.h"
-#include "core/common/status.h"
-#include "gpu_data_transfer.h"
+#include "core/providers/migraphx/gpu_data_transfer.h"

 namespace onnxruntime {

@@ -55,7 +54,9 @@ void MIGraphXExternalAllocator::Free(void* p) {
   auto it = reserved_.find(p);
   if (it != reserved_.end()) {
     reserved_.erase(it);
-    if (empty_cache_) empty_cache_();
+    if (empty_cache_ != nullptr) {
+      empty_cache_();
+    }
   }
 }
diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h
index f6b7788e0604c..10e06ab2f35ad 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h
@@ -3,9 +3,9 @@
 #pragma once

+#include
 #include "core/framework/allocator.h"
-#include

 namespace onnxruntime {
diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc
index 9807cd646e51c..79dfb5512d3b5 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_call.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc
@@ -1,13 +1,15 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#include
+#include
+
 #ifdef _WIN32
 #include
 #else
 #include
 #endif

-#include
 #include "core/common/common.h"
 #include "core/common/status.h"
 #include "core/providers/shared_library/provider_api.h"
@@ -15,10 +17,9 @@ namespace onnxruntime {

-using namespace common;
-
+namespace {
 template <typename ERRTYPE>
-const char* RocmErrString(ERRTYPE x) {
+std::string_view RocmErrString(ERRTYPE x) {
   ORT_NOT_IMPLEMENTED();
 }

@@ -27,14 +28,16 @@ const char* RocmErrString(ERRTYPE x) {
     return #x

 template <>
-const char* RocmErrString(hipError_t x) {
+std::string_view RocmErrString(hipError_t x) {
   (void)hipDeviceSynchronize();
-  return hipGetErrorString(x);
+  return std::string_view{hipGetErrorString(x)};
 }

+}  // namespace
+
 template <typename ERRTYPE, bool THRW>
 std::conditional_t<THRW, void, Status> RocmCall(
-    ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) {
+    ERRTYPE retCode, std::string_view exprString, std::string_view libName, ERRTYPE successCode, std::string_view msg, std::string_view file, const int line) {
   if (retCode != successCode) {
     try {
 #ifdef _WIN32
@@ -47,17 +50,16 @@ std::conditional_t<THRW, void, Status> RocmCall(
       int currentHipDevice;
       (void)hipGetDevice(&currentHipDevice);
       (void)hipGetLastError();  // clear last HIP error
-      static char str[1024];
-      snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s",
-               libName, (int)retCode, RocmErrString(retCode), currentHipDevice,
-               hostname.c_str(),
-               file, line, exprString, msg);
+      std::stringstream ss;
+      ss << libName << " failure " << static_cast<int>(retCode) << ": " << RocmErrString(retCode)
+         << "; GPU=" << currentHipDevice << "; hostname=" << hostname << "; file=" << file << "; line=" << line
+         << "; expr=" << exprString << "; " << msg;
       if constexpr (THRW) {
         // throw an exception with the error info
-        ORT_THROW(str);
+        ORT_THROW(ss.str());
       } else {
-        LOGS_DEFAULT(ERROR) << str;
-        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str);
+        LOGS_DEFAULT(ERROR) << ss.str();
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ss.str());
       }
     } catch (const std::exception& e) {  // catch, log, and rethrow since HIP code sometimes hangs in destruction, so we'd never get to see the error
       if constexpr (THRW) {
@@ -73,7 +75,7 @@ std::conditional_t<THRW, void, Status> RocmCall(
   }
 }

-template Status RocmCall<hipError_t, false>(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line);
-template void RocmCall<hipError_t, true>(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line);
+template Status RocmCall<hipError_t, false>(hipError_t retCode, std::string_view exprString, std::string_view libName, hipError_t successCode, std::string_view msg, std::string_view file, int line);
+template void RocmCall<hipError_t, true>(hipError_t retCode, std::string_view exprString, std::string_view libName, hipError_t successCode, std::string_view msg, std::string_view file, int line);

 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h
index 6d514e01aea96..9c3b5c79a947b 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_call.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_call.h
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
#pragma once -#include "migraphx_inc.h" +#include "core/providers/migraphx/migraphx_inc.h" #include "core/common/common.h" namespace onnxruntime { @@ -13,7 +13,7 @@ namespace onnxruntime { template std::conditional_t RocmCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); + ERRTYPE retCode, std::string_view exprString, std::string_view libName, ERRTYPE successCode, std::string_view msg, std::string_view file, int line); #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 41b55e3baf508..a59347841be95 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1,26 +1,34 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License -#include + +#include + #include +#include +#include +#include #include -#include +#include +#include #include -#include +#include +#include +#include +#include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/safeint.h" #include "core/common/logging/severity.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_execution_provider_utils.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" -#include -#include "migraphx_call.h" - -#include "migraphx_stream_handle.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_execution_provider_utils.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_call.h" +#include "core/providers/migraphx/migraphx_stream_handle.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245) @@ -105,240 +113,144 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c return s_kernel_registry; } -MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, - info.device_id)}, - info_(info) { - InitProviderOrtApi(); - get_flags_from_session_info(info); - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); - get_flags_from_env(); +static std::string_view GetArenaExtendStrategyName(ArenaExtendStrategy strategy) { + switch (strategy) { + case ArenaExtendStrategy::kNextPowerOfTwo: + return "kNextPowerOfTwo"; + case ArenaExtendStrategy::kSameAsRequested: + return "kSameAsRequested"; + default: + return "Unknown"; + } } -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { -} +#define GET_ENV(variable, value, ...) 
\ + const auto value##env{GetEnvironmentVar(variable)}; \ + if (!value##env.empty()) { \ + __VA_ARGS__; \ + LOGS_DEFAULT(INFO) << "\n " << variable << ": " << value##env; \ + } -void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { - // Set GPU device to be used - HIP_CALL_THROW(hipSetDevice(info_.device_id)); - t_ = migraphx::target(info.target_device.c_str()); +#define GET_ENV_BOOL(variable, value) \ + GET_ENV(variable, value, value = std::stoi(value##env) != 0) - // Quantization - fp16_enable_ = info.fp16_enable; +#define GET_ENV_STRING(variable, value) \ + GET_ENV(variable, value, value = value##env) +MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) + : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, info.device_id)}, + device_id_{info.device_id}, + fp16_enable_{info.fp16_enable}, +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR > 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH >= 2))) + bf16_enable_{info.bf16_enable}, +#endif #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = info.fp8_enable; -#else - LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; - fp8_enable_ = false; + fp8_enable_{info.fp8_enable}, #endif - int8_enable_ = info.int8_enable; - - if (int8_enable_ and fp8_enable_) { - LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; - } - - if (int8_enable_ xor fp8_enable_) { - int8_calibration_cache_name_ = info.int8_calibration_table_name; - int8_use_native_migraphx_calibration_table_ = info.int8_use_native_calibration_table; - } - - if (int8_enable_ or fp8_enable_) { - int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); - } + int8_enable_{info.int8_enable}, + model_cache_path_{info.model_cache_dir}, + t_{info.target_device.c_str()}, + exhaustive_tune_{info.exhaustive_tune}, + metadef_id_generator_{ModelMetadefIdGenerator::Create()}, + external_alloc_{info.external_alloc}, + external_free_{info.external_free}, + external_empty_cache_{info.external_empty_cache} { + InitProviderOrtApi(); - // Load INT8 calibration table - std::unordered_map dynamic_range_map; - if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path); - } - } + // Set GPU device to be used and read device properties for feature usage. - // Save/load migraphx compiled models - save_compiled_model_ = info.save_compiled_model; - save_compiled_path_ = info.save_model_file; - load_compiled_model_ = info.load_compiled_model; - load_compiled_path_ = info.load_model_file; + HIP_CALL_THROW(hipSetDevice(device_id_)); + HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, device_id_)); - exhaustive_tune_ = info.exhaustive_tune; + // Overwrite initialized values with values from environment variables. 
- LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider Session Options:"; - print_migraphx_ep_flags(); -} + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX ENV Override Variables Set:"; + GET_ENV_BOOL(migraphx_env_vars::kFP16Enable, fp16_enable_); +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR > 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH >= 2))) + GET_ENV_BOOL(migraphx_env_vars::kBF16Enable, bf16_enable_); +#endif +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) + GET_ENV_BOOL(migraphx_env_vars::kFP8Enable, fp8_enable_); +#endif + GET_ENV_BOOL(migraphx_env_vars::kINT8Enable, int8_enable_); + GET_ENV(migraphx_env_vars::kINT8CalibrationTableName, int8_calibration_cache_name_); + GET_ENV(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable, int8_use_native_migraphx_calibration_table_); + GET_ENV_STRING(migraphx_env_vars::kCachePath, calibration_cache_path_); + GET_ENV_STRING(migraphx_env_vars::kModelCachePath, model_cache_path_); + GET_ENV_BOOL(migraphx_env_vars::kDumpModelOps, dump_model_ops_); + GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); + + // Verify configuration correctness and adjust accordingly. + +#if HIP_VERSION_MAJOR < 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR < 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH < 2))) + LOGS_DEFAULT(WARNING) << "MIGraphX: BF16 Quantization requires ROCm 6.4.2 or greater"; + bf16_enable_ = false; +#endif -void MIGraphXExecutionProvider::get_flags_from_env() { - LOGS_DEFAULT(WARNING) << "\n[MIGraphX EP] MIGraphX ENV Override Variables Set:"; - // whether fp16 is enable - const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); - if (!fp16_enable_env.empty()) { - fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; + if (bf16_enable_ && fp16_enable_) { + bf16_enable_ = false; + fp16_enable_ = false; + LOGS_DEFAULT(FATAL) << "MIGraphX: BF16 and FP16 Quantization Mutually exclusive. Ignoring both Quantization flags"; } - // whether fp8 quantization is enabled - const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); - if (!fp8_enable_env.empty()) { -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; -#else - LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; - fp8_enable = false; +#if HIP_VERSION_MAJOR < 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR < 4) + LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; + fp8_enable_ = false; #endif - } - // whether int8 is enabled - const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); - if (!int8_enable_env.empty()) { - int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; + if (int8_enable_ && fp8_enable_) { + LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } - if (int8_enable_ and fp8_enable_) { - LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP8 and INT8 Quantization Mutually exclusive. 
Ignoring both Quantization flags"; + if (int8_enable_ ^ fp8_enable_) { + int8_calibration_table_name_ = + int8_calibration_cache_name_env.empty() ? info.int8_calibration_table_name : int8_calibration_cache_name_env; + int8_use_native_calibration_table_ = + int8_use_native_migraphx_calibration_table_env.empty() ? info.int8_use_native_calibration_table : std::stoi(int8_use_native_migraphx_calibration_table_env) != 0; } if (int8_enable_ || fp8_enable_) { - const std::string int8_calibration_cache_name_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); - if (!int8_calibration_cache_name_env.empty()) { - int8_calibration_cache_name_ = int8_calibration_cache_name_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; - } - - const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); - if (!cache_path.empty()) { - calibration_cache_path_ = cache_path; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; - } - - const std::string int8_use_native_migraphx_calibration_table_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); - if (!int8_use_native_migraphx_calibration_table_env.empty()) { - int8_use_native_migraphx_calibration_table_ = - (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " - << int8_use_native_migraphx_calibration_table_; - } - } - - if (int8_enable_ or fp8_enable_) { - int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); + int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); } // Load INT8 calibration table - std::unordered_map dynamic_range_map; if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path); + std::unordered_map dynamic_range_map; + auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_table_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path.string()); } } - // Save/load migraphx compiled models - const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); - if (!save_comp_model_env.empty()) { - save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_MODEL: " << save_compiled_model_; - } - - const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); - if (save_compiled_model_ && !save_model_path_env.empty()) { - save_compiled_path_ = save_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_PATH: " << save_compiled_path_; - } - - const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); - if (!load_comp_model_env.empty()) { - load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? 
false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_MODEL: " << load_compiled_model_; - } - - const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); - if (load_compiled_model_ && !load_model_path_env.empty()) { - load_compiled_path_ = load_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_PATH: " << load_compiled_path_; - } - - // dump unsupported ops - const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); - if (!dump_model_ops_env.empty()) { - dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; - } + // Print configured options for the session. - // Allow for exhaustive tune during compile - const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); - if (!exhaustive_tune_env.empty()) { - exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; - } -} - -void MIGraphXExecutionProvider::print_migraphx_ep_flags() { - LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id - << "\n migraphx_fp16_enable: " << fp16_enable_ - << "\n migraphx_fp8_enable: " << fp8_enable_ - << "\n migraphx_int8_enable: " << int8_enable_ + LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider Session Options:" + << "\n " << migraphx_provider_option::kDeviceId << ": " << device_id_ + << "\n " << migraphx_provider_option::kFp16Enable << ": " << fp16_enable_ + << "\n " << migraphx_provider_option::kBf16Enable << ": " << bf16_enable_ + << "\n " << migraphx_provider_option::kFp8Enable << ": " << fp8_enable_ + << "\n " << migraphx_provider_option::kInt8Enable << ": " << int8_enable_ + << "\n " << migraphx_provider_option::kMemLimit << ": " << mem_limit_ + << "\n " << migraphx_provider_option::kArenaExtendStrategy << ": " << GetArenaExtendStrategyName(arena_extend_strategy_) << "\n dump_model_ops: " << dump_model_ops_ - << "\n exhaustive_tune: " << exhaustive_tune_ - << "\n migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << "\n " << migraphx_provider_option::kExhaustiveTune << ": " << exhaustive_tune_ + << "\n " << migraphx_provider_option::kInt8CalibTable << ": " << int8_calibration_table_name_ << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ - << "\n use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ - << "\n migraphx_save_compiled_model: " << save_compiled_model_ - << "\n migraphx_save_compiled_model_path: " << save_compiled_path_ - << "\n migraphx_load_compiled_model: " << load_compiled_model_ - << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; -} - -AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, - size_t migx_mem_limit, - ArenaExtendStrategy arena_extend_strategy, - MIGraphXExecutionProviderExternalAllocatorInfo - external_allocator_info, - const OrtArenaCfg* default_memory_arena_cfg) { - if (external_allocator_info.UseExternalAllocator()) { - AllocatorCreationInfo default_memory_info( - [external_allocator_info](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP, - external_allocator_info.alloc, - external_allocator_info.free, - external_allocator_info.empty_cache); - }, - device_id, - false); - - return CreateAllocator(default_memory_info); - } else 
{ - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP); - }, - device_id, - true, - {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), - -1, -1, -1, -1L)}, - // make it stream aware - true); - - // ROCM malloc/free is expensive so always use an arena - return CreateAllocator(default_memory_info); - } + << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_calibration_table_ + << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_; } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, - info_.device_id); + [](OrtDevice::DeviceId device_id) { + return std::make_unique(device_id, onnxruntime::CUDA); + }, + device_id_); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, onnxruntime::CUDA_PINNED); + return std::make_unique(device_id, CUDA_PINNED); }, - info_.device_id); + device_id_); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; } @@ -354,6 +266,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: @@ -384,6 +297,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: mgx_type = migraphx_shape_half_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + mgx_type = migraphx_shape_bf16_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: mgx_type = migraphx_shape_float_type; break; @@ -457,7 +373,7 @@ std::vector toVector(const ONNX_NAMESPACE::int64s& nums) { static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) { std::vector input_nodes; const auto& optype = node->OpType(); - if (optype == "ArgMax" or optype == "ArgMin") { + if (optype == "ArgMax" || optype == "ArgMin") { const auto& attributes = node->GetAttributes(); // we do not support select_last_index = 1 for now auto sli_attr = attributes.find("select_last_index"); @@ -475,7 +391,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } - if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and + if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) && (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) { return true; } @@ -503,7 +419,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co // storage order 1 (column major format) is not supported auto storage_order_attr = attributes.find("storage_order"); - if (storage_order_attr != attributes.end() and (*storage_order_attr).second.i() != 0) { + if (storage_order_attr != attributes.end() && (*storage_order_attr).second.i() != 0) { return true; } @@ -513,7 +429,7 
@@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } auto data_type = input_type->tensor_type().elem_type(); - if (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 or + if (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) { return true; } @@ -524,7 +440,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } - if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and + if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) && (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) { return true; } @@ -580,7 +496,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } return true; } - } else if (optype == "Resize" or optype == "Upsample") { + } else if (optype == "Resize" || optype == "Upsample") { const auto& attributes = node->GetAttributes(); auto ct_attr = attributes.find("coordinate_transformation_mode"); if (ct_attr != attributes.end()) { @@ -618,7 +534,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } const auto& attributes = node->GetAttributes(); - if (attributes.count("starts") > 0 and attributes.count("ends") > 0) { + if (attributes.count("starts") > 0 && attributes.count("ends") > 0) { auto starts = toVector((*attributes.find("starts")).second.ints()); auto ends = toVector((*attributes.find("ends")).second.ints()); for (std::size_t i = 0; i < starts.size(); ++i) { @@ -656,7 +572,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } - } else if (optype == "Unsqueeze" or optype == "Squeeze") { + } else if (optype == "Unsqueeze" || optype == "Squeeze") { const auto& args = node->InputDefs(); if (args.size() == 2) { if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { @@ -685,9 +601,9 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v if (args.size() == 2) { std::vector node_inputs; if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) { - return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) { - return std::find(git.begin(), git.end(), index) != git.end(); - })); + return !std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto i) { + return std::find(git.begin(), git.end(), i) != git.end(); + }); } else { return true; } @@ -857,12 +773,14 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st erased.insert(output); } // Only when output is neither in input list nor erased list, add the output to output list - else if (erased.find(output) == erased.end()) { - if (std::find(graph_output_names.begin(), - graph_output_names.end(), output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; + else { + if (erased.find(output) == erased.end()) { + if (std::find(graph_output_names.begin(), + graph_output_names.end(), output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; } - fused_outputs[output] = output_order++; } } } @@ -944,6 +862,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Atan", 
"Atanh", "ATen", + "Attention", "AveragePool", "BatchNormalization", "BiasGelu", @@ -986,6 +905,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Greater", "GreaterOrEqual", "GroupNormalization", + "GroupNorm", "GroupQueryAttention", "HardSigmoid", "HardSwish", @@ -1017,6 +937,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "MultiHeadAttention", "Neg", "NegativeLogLikelihoodLoss", + "NhwcConv", "NonMaxSuppression", "NonZero", "Not", @@ -1053,6 +974,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "ReverseSequence", "RNN", "Roialign", + "RotaryEmbedding", "Round", "Scatter", "ScatterElements", @@ -1243,29 +1165,25 @@ bool get_input_output_names(const GraphViewer& graph, // Attempt to load a model and catch any exceptions on load fail. // Useful to default to EP to trigger the compile if file doesn't exist or loading fails. -bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { - try { - if (load_enable) { - LOGS_DEFAULT(WARNING) << "Attempting to load model at:" << path; - prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(WARNING) << "load model : Success"; - return true; - } else { - return false; - } - } catch (...) { - return false; +bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path& path) try { + if (!path.empty() && exists(path)) { + LOGS_DEFAULT(VERBOSE) << "Attempting to load model at:" << path.string(); + prog = migraphx::load(path.string().c_str()); + LOGS_DEFAULT(VERBOSE) << "load model : Success"; + return true; } return false; +} catch (...) { + return false; } -void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { - if (save_enable) { - LOGS_DEFAULT(WARNING) << "Model Save at " << out_path << ": Begin"; +void save_compiled_model(const migraphx::program& prog, const std::filesystem::path& path) { + if (!path.empty()) { + LOGS_DEFAULT(VERBOSE) << "Model Save at " << path.string() << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); - migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(WARNING) << "Model Save: Complete"; + save(prog, path.string().c_str(), fo); + LOGS_DEFAULT(VERBOSE) << "Model Save: Complete"; } } @@ -1275,12 +1193,13 @@ void calibrate_and_quantize(migraphx::program& prog, const migraphx::target& t, const migraphx::program_parameters quant_params, bool fp16_enable, + bool bf16_enable, bool int8_enable, bool fp8_enable, bool int8_calibration_cache_available, std::unordered_map& dynamic_range_map) { // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { + if ((int8_enable ^ fp8_enable) && int8_calibration_cache_available) { LOGS_DEFAULT(WARNING) << "Quantizing input program"; auto param_shapes = prog.get_parameter_shapes(); @@ -1317,6 +1236,14 @@ void calibrate_and_quantize(migraphx::program& prog, migraphx::quantize_fp16(prog); LOGS_DEFAULT(WARNING) << "Quantizing fp16: Complete"; } + +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) + if (bf16_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to bf16"; + migraphx::quantize_bf16(prog); + LOGS_DEFAULT(WARNING) << "Quantizing bf16: Complete"; + } +#endif } void compile_program(migraphx::program& prog, @@ -1330,6 +1257,27 @@ void compile_program(migraphx::program& prog, LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; } +std::string to_hex(const uint64_t 
v) { + std::array s{}; + auto [ptr, _] = std::to_chars(s.data(), s.data() + s.size(), v, 16); + return std::string{s.data(), ptr}; +} + +template +std::string make_hash(T v) { + std::array temp{}; + MurmurHash3::x86_128(v.data(), gsl::narrow_cast(v.size()), temp[0], temp.data()); + return to_hex(temp[0] | static_cast(temp[1]) << 32); +} + +template <> +std::string make_hash(const char* v) { + return make_hash(std::string_view{v}); +} + +constexpr std::uint64_t MIGraphX_Version = + ((MIGRAPHX_VERSION_MAJOR << 16) | (MIGRAPHX_VERSION_MINOR << 8) | MIGRAPHX_VERSION_PATCH); + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1337,6 +1285,33 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& for (const auto& fused_node_graph : fused_nodes) { const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; const Node& fused_node = fused_node_graph.fused_node; + + std::filesystem::path model_cache_file; + auto mxr_filename_prefix = to_hex(MIGraphX_Version) + "-" + GenerateGraphId(graph_body_viewer) + "-" + make_hash(std::string_view(device_prop_.gcnArchName)) + "-"; + + // Get model input names (only first layer) + const Graph* cur_graph = &graph_body_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + const Graph& main_graph = *cur_graph; + const auto& input_tensor = main_graph.GetInputs(); + for (auto i : input_tensor) { + session_input_names.insert(i->Name()); + } + + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + std::vector input_shapes; + for (std::size_t i = 0; i < session_input_names.size(); ++i) { + auto tensor_shape = input_tensor[i]->Shape(); + for (int j = 1; j < tensor_shape->dim_size(); ++j) { + input_shapes.push_back(tensor_shape->dim(j).dim_value()); + } + } + model_cache_file = model_cache_path_ / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + // map parameter input name to index std::unordered_map input_name_index; const auto& input_defs = fused_node.InputDefs(); @@ -1367,15 +1342,20 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::program prog; if (!no_input_shape) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; + if (!load_precompiled_model(prog, model_cache_file)) { + LOGS_DEFAULT(VERBOSE) << "No input shapes detected quantizing model"; +#ifndef ENABLE_TRAINING_CORE +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + options.set_external_data_path(model_path_.parent_path().string()); +#endif +#endif prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); migraphx::program_parameters quant_params; - calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, bf16_enable_, int8_enable_, fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); compile_program(prog, t_, exhaustive_tune_); - save_compiled_model(prog, save_compiled_model_, save_compiled_path_); + save_compiled_model(prog, model_cache_file); } auto prog_output_shapes = prog.get_output_shapes(); @@ -1396,10 +1376,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], 
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, - save_compiled_model_, save_compiled_path_, - load_compiled_model_, load_compiled_path_, dump_model_ops_}; + model_cache_path_.string(), dump_model_ops_}; *state = p.release(); return 0; }; @@ -1409,7 +1388,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& delete static_cast(state); }; - compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* api, OrtKernelContext* context) { Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); @@ -1421,6 +1400,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool bf16_enable = mgx_state->bf16_enable; bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; @@ -1429,8 +1409,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // from input data bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; + std::vector input_shapes; + if (no_input_shape) { - LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; + LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1442,7 +1424,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; + LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1459,8 +1441,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto mgx_s = param_shapes[name]; auto mgx_lens = mgx_s.lengths(); auto mgx_strides = mgx_s.strides(); - if (mgx_lens.size() == 1 and mgx_lens[0] == 1 and - mgx_strides.size() == 1 and mgx_strides[0] == 0) { + if (mgx_lens.size() == 1 && mgx_lens[0] == 1 && + mgx_strides.size() == 1 && mgx_strides[0] == 0) { mgx_lens.clear(); } @@ -1468,6 +1450,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& cmp_options.set_input_parameter_shape(name, ort_lens); input_shape_match = false; } + input_shapes.insert(input_shapes.end(), tensor_shape.begin(), tensor_shape.end()); } } } @@ -1476,20 +1459,25 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // input shapes are different, needs to re-parse onnx and // re-compile the program if (!input_shape_match) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. 
Recompiling" << std::endl; + std::filesystem::path model_cache_file; + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + model_cache_file = mgx_state->model_cache_dir / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + if (!load_precompiled_model(prog, model_cache_file)) { + LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling"; #ifndef ENABLE_TRAINING_CORE -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 2) - cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string()); +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + cmp_options.set_external_data_path(model_path_.parent_path().string()); #endif #endif prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); migraphx::program_parameters quant_params; - if ((int8_enable xor fp8_enable) and int8_calibration_cache_available) { - auto param_shapes = prog.get_parameter_shapes(); + if ((int8_enable ^ fp8_enable) && int8_calibration_cache_available) { + auto local_param_shapes = prog.get_parameter_shapes(); // Add input parameter data and the values they're set to - for (auto&& name : param_shapes.names()) { + for (auto&& name : local_param_shapes.names()) { if (map_input_name_index.count(name) > 0) { auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); @@ -1498,19 +1486,19 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx_shape_datatype_t mgx_type; getMIGraphXType(tensor_type, mgx_type); - auto mgx_s = param_shapes[name]; + auto mgx_s = local_param_shapes[name]; if (mgx_type != mgx_s.type()) { LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - quant_params.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); + quant_params.add(name, migraphx::argument(local_param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } } } - calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, + calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, fp8_enable, int8_calibration_cache_available, map_dynamic_range); compile_program(prog, t, exhaustive_tune_); - save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); + save_compiled_model(prog, model_cache_file); } mgx_state->prog = prog; @@ -1524,7 +1512,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; + LOGS_DEFAULT(VERBOSE) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1538,21 +1526,21 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; + LOGS_DEFAULT(VERBOSE) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } - // It is a output argument + // It is an output argument else { - auto compute_output_index = [](const std::string& name) -> int { - std::string out_name_prefix = "#output_"; - auto pos = 
name.find(out_name_prefix); - if (pos == std::string::npos) { + auto compute_output_index = [](const std::string_view sv) -> int { + constexpr std::string_view out_name_prefix = "#output_"; + const auto pos = sv.find(out_name_prefix); + if (pos == std::string_view::npos) { return -1; } - std::string index_str = name.substr(pos + out_name_prefix.length()); - return std::stoi(index_str); + const auto index_str = sv.substr(pos + out_name_prefix.length()); + return ToInteger(Trim(index_str, std::isdigit)); }; int output_index = compute_output_index(name); @@ -1599,7 +1587,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& static_cast(rocm_stream))); } } - }; + } return Status::OK(); }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index aecccdd54d697..99f790b9f9f7a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,33 +3,37 @@ #pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" -#include +#include "core/framework/provider_options_utils.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/migraphx/migraphx_call.h" -#include -#include -#include +using namespace std::literals::string_view_literals; namespace onnxruntime { namespace migraphx_env_vars { -static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; -static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; -static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; -static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; -static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; -static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; -static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; -static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; -static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"; -static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; -static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"; -static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; - -}; // namespace migraphx_env_vars +constexpr auto kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE"sv; +constexpr auto kBF16Enable = "ORT_MIGRAPHX_BF16_ENABLE"sv; +constexpr auto kFP8Enable = "ORT_MIGRAPHX_FP8_ENABLE"sv; +constexpr auto kINT8Enable = "ORT_MIGRAPHX_INT8_ENABLE"sv; +constexpr auto kDumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS"sv; +constexpr auto kINT8CalibrationTableName = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"sv; +constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH"sv; +constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"sv; +constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; +constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; +} // namespace migraphx_env_vars // Information to construct kernel function state. 
struct MIGraphXFuncState { @@ -44,14 +48,12 @@ struct MIGraphXFuncState { std::mutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool bf16_enable = false; bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; - bool save_compiled_mode = false; - std::string save_compiled_path; - bool load_compiled_mode = false; - std::string load_compiled_path; + std::filesystem::path model_cache_dir; bool dump_model_ops = false; bool exhaustive_tune = false; }; @@ -60,11 +62,7 @@ struct MIGraphXFuncState { class MIGraphXExecutionProvider : public IExecutionProvider { public: explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); - ~MIGraphXExecutionProvider(); - - void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); - void get_flags_from_env(); - void print_migraphx_ep_flags(); + ~MIGraphXExecutionProvider() override = default; Status Sync() const override; @@ -81,42 +79,55 @@ class MIGraphXExecutionProvider : public IExecutionProvider { common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; - virtual std::shared_ptr GetKernelRegistry() const override; + std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; - static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, - MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); - std::unique_ptr GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; - int GetDeviceId() const override { return info_.device_id; } + int GetDeviceId() const override { return device_id_; } ProviderOptions GetProviderOptions() const override { - return MIGraphXExecutionProviderInfo::ToProviderOptions(info_); + return { + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(device_id_)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(fp16_enable_)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable_)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable_)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable_)}, + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name_)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table_)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune_)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit_)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy_)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc_)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free_)}, + 
{std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache_)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_path_)}}; } private: - MIGraphXExecutionProviderInfo info_; + OrtDevice::DeviceId device_id_{0}; bool fp16_enable_ = false; + bool bf16_enable_ = false; bool fp8_enable_ = false; bool int8_enable_ = false; - std::string int8_calibration_cache_name_; + std::string int8_calibration_table_name_; bool int8_calibration_cache_available_ = false; - bool int8_use_native_migraphx_calibration_table_ = false; - std::string calibration_cache_path_; + bool int8_use_native_calibration_table_ = false; + std::filesystem::path calibration_cache_path_{}; std::unordered_map dynamic_range_map_; - bool save_compiled_model_ = false; - std::string save_compiled_path_; - bool load_compiled_model_ = false; - std::string load_compiled_path_; + std::filesystem::path model_cache_path_{}; + std::set session_input_names; bool dump_model_ops_ = false; migraphx::target t_; std::mutex mgx_mu_; hipStream_t stream_ = nullptr; + hipDeviceProp_t device_prop_{}; bool exhaustive_tune_ = false; - mutable std::filesystem::path model_path_; + mutable std::filesystem::path model_path_{}; + size_t mem_limit_{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy_{ArenaExtendStrategy::kNextPowerOfTwo}; std::unordered_map map_progs_; std::unordered_map map_onnx_string_; @@ -125,6 +136,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { AllocatorPtr allocator_; std::unique_ptr metadef_id_generator_; + void* external_alloc_{nullptr}; + void* external_free_{nullptr}; + void* external_empty_cache_{nullptr}; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index cf21d791cfe6b..33ef366eb18e5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -1,14 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
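// --- Illustrative sketch (not the provider's code verbatim): how the MXR cache file name
// --- introduced above can be assembled from the packed MIGraphX version, a graph
// --- fingerprint, the GPU architecture string and the concrete input shapes. std::hash
// --- stands in for MurmurHash3::x86_128 so the snippet stays self-contained; the names
// --- in namespace sketch are assumptions, not ORT symbols.
#include <charconv>
#include <cstdint>
#include <filesystem>
#include <functional>
#include <string>
#include <string_view>
#include <vector>

namespace sketch {

inline std::string ToHex(std::uint64_t v) {
  char buf[17]{};
  auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), v, 16);
  return std::string{buf, ptr};
}

inline std::string HashToHex(std::string_view s) {
  return ToHex(std::hash<std::string_view>{}(s));  // stand-in for MurmurHash3
}

// version packed as (major << 16) | (minor << 8) | patch, as in the provider
inline std::filesystem::path MakeCacheFile(const std::filesystem::path& cache_dir,
                                           std::uint64_t packed_version,
                                           std::string_view graph_id,
                                           std::string_view gcn_arch,
                                           const std::vector<std::int64_t>& input_dims) {
  std::string shape_blob;
  for (auto d : input_dims) shape_blob += ToHex(static_cast<std::uint64_t>(d)) + ",";
  const std::string name = ToHex(packed_version) + "-" + std::string{graph_id} + "-" +
                           HashToHex(gcn_arch) + "-" + HashToHex(shape_blob) + ".mxr";
  return cache_dir / name;  // e.g. <cache>/10800-<graph>-<arch>-<shapes>.mxr
}

}  // namespace sketch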
+#include + #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/common/make_string.h" #include "core/common/parse_string.h" -#include "core/framework/provider_options_utils.h" -#include "migraphx_inc.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/migraphx_inc.h" +#include "core/providers/migraphx/migraphx_call.h" namespace onnxruntime { @@ -17,118 +18,90 @@ const EnumNameMapping arena_extend_strategy_mapping{ {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, }; -namespace migraphx { -namespace provider_option_names { -constexpr const char* kDeviceId = "device_id"; -constexpr const char* kFp16Enable = "trt_fp16_enable"; -constexpr const char* kFp8Enable = "migx_fp8_enable"; -constexpr const char* kInt8Enable = "migx_int8_enable"; -constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; -constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; -constexpr const char* kSaveCompiledModel = "migx_save_compiled_model"; -constexpr const char* kSaveModelPath = "migx_save_model_name"; -constexpr const char* kLoadCompiledModel = "migx_load_compiled_model"; -constexpr const char* kLoadModelPath = "migx_load_model_name"; -constexpr const char* kExhaustiveTune = "migx_exhaustive_tune"; -constexpr const char* kMemLimit = "migx_mem_limit"; -constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy"; -constexpr const char* kGpuExternalAlloc = "migx_external_alloc"; -constexpr const char* kGpuExternalFree = "migx_external_free"; -constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache"; - -} // namespace provider_option_names -} // namespace migraphx - -MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { - MIGraphXExecutionProviderInfo info{}; - void* alloc = nullptr; - void* free = nullptr; - void* empty_cache = nullptr; +MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptions& options) { ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( - migraphx::provider_option_names::kDeviceId, - [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); + migraphx_provider_option::kDeviceId, + [this](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, device_id)); int num_devices{}; ORT_RETURN_IF_ERROR(HIP_CALL(hipGetDeviceCount(&num_devices))); ORT_RETURN_IF_NOT( - 0 <= info.device_id && info.device_id < num_devices, - "Invalid device ID: ", info.device_id, + 0 <= device_id && device_id < num_devices, + "Invalid device ID: ", device_id, ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalAlloc, - [&alloc](const std::string& value_str) -> Status { - size_t address; + migraphx_provider_option::kGpuExternalAlloc, + [this](const std::string& value_str) -> Status { + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - alloc = reinterpret_cast(address); + external_alloc = reinterpret_cast(address); return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalFree, - [&free](const std::string& value_str) -> Status { - size_t address; + migraphx_provider_option::kGpuExternalFree, + [this](const std::string& value_str) -> 
Status { + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - free = reinterpret_cast(address); + external_free = reinterpret_cast(address); return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalEmptyCache, - [&empty_cache](const std::string& value_str) -> Status { - size_t address; + migraphx_provider_option::kGpuExternalEmptyCache, + [this](const std::string& value_str) -> Status { + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - empty_cache = reinterpret_cast(address); + external_empty_cache = reinterpret_cast(address); + return Status::OK(); + }) + .AddValueParser( + migraphx_provider_option::kModelCacheDir, + [this](const std::string& value_str) -> Status { + model_cache_dir = ToPathString(value_str); return Status::OK(); }) - .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune) - .AddAssignmentToReference(migraphx::provider_option_names::kMemLimit, info.mem_limit) - .AddAssignmentToEnumReference(migraphx::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) + .AddAssignmentToReference(migraphx_provider_option::kFp16Enable, fp16_enable) + .AddAssignmentToReference(migraphx_provider_option::kBf16Enable, bf16_enable) + .AddAssignmentToReference(migraphx_provider_option::kFp8Enable, fp8_enable) + .AddAssignmentToReference(migraphx_provider_option::kInt8Enable, int8_enable) + .AddAssignmentToReference(migraphx_provider_option::kInt8UseNativeCalibTable, int8_use_native_calibration_table) + .AddAssignmentToReference(migraphx_provider_option::kInt8CalibTable, int8_calibration_table_name) + .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, exhaustive_tune) + .AddAssignmentToReference(migraphx_provider_option::kMemLimit, mem_limit) + .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) .Parse(options)); - - MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; - info.external_allocator_info = alloc_info; - - return info; } -ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXExecutionProviderInfo& info) { - const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)}, - 
{migraphx::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, - {migraphx::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, - {migraphx::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, - {migraphx::provider_option_names::kArenaExtendStrategy, - EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)}, - }; - return options; +MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const OrtMIGraphXProviderOptions& options) noexcept + : device_id{static_cast(options.device_id)}, + fp16_enable{options.migraphx_fp16_enable != 0}, + fp8_enable{options.migraphx_fp8_enable != 0}, + int8_enable{options.migraphx_int8_enable != 0}, + exhaustive_tune{options.migraphx_exhaustive_tune != 0}, + mem_limit{options.migraphx_mem_limit}, + arena_extend_strategy{options.migraphx_arena_extend_strategy} { } -ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) { - const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)}, - {migraphx::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.migraphx_arena_extend_strategy))}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, +ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { + return { + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(device_id)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(fp16_enable)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable)}, + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, 
MakeStringWithClassicLocale(external_alloc)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free)}, + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_dir)}, }; - return options; } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index a598052c5f025..414254aaa2629 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -3,70 +3,79 @@ #pragma once +#include #include #include +#include #include "core/framework/ortdevice.h" #include "core/common/hash_combine.h" #include "core/framework/arena_extend_strategy.h" #include "core/framework/provider_options.h" +#include "core/framework/provider_options_utils.h" #include "core/session/onnxruntime_c_api.h" -namespace onnxruntime { - -// Information needed to construct MIGraphX execution providers. -struct MIGraphXExecutionProviderExternalAllocatorInfo { - void* alloc{nullptr}; - void* free{nullptr}; - void* empty_cache{nullptr}; - - MIGraphXExecutionProviderExternalAllocatorInfo() { - alloc = nullptr; - free = nullptr; - empty_cache = nullptr; - } +using namespace std::literals::string_view_literals; - MIGraphXExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) { - alloc = a; - free = f; - empty_cache = e; - } +namespace onnxruntime { - bool UseExternalAllocator() const { - return (alloc != nullptr) && (free != nullptr); - } -}; +namespace migraphx_provider_option { +constexpr auto kDeviceId = "device_id"sv; +constexpr auto kFp16Enable = "migraphx_fp16_enable"sv; +constexpr auto kBf16Enable = "migraphx_bf16_enable"sv; +constexpr auto kFp8Enable = "migraphx_fp8_enable"sv; +constexpr auto kInt8Enable = "migraphx_int8_enable"sv; +constexpr auto kInt8CalibTable = "migraphx_int8_calibration_table_name"sv; +constexpr auto kInt8UseNativeCalibTable = "migraphx_int8_use_native_calibration_table"sv; +constexpr auto kExhaustiveTune = "migraphx_exhaustive_tune"sv; +constexpr auto kMemLimit = "migraphx_mem_limit"sv; +constexpr auto kArenaExtendStrategy = "migraphx_arena_extend_strategy"sv; +constexpr auto kGpuExternalAlloc = "migraphx_external_alloc"sv; +constexpr auto kGpuExternalFree = "migraphx_external_free"sv; +constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; +constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; +} // namespace migraphx_provider_option + +extern const EnumNameMapping arena_extend_strategy_mapping; // Information needed to construct trt execution providers. 
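// --- Illustrative sketch (simplified, not the ProviderOptionsParser API): the external
// --- alloc/free/empty-cache options arrive as integer address strings, are parsed into
// --- std::uintptr_t and only then reinterpreted as pointers, which is the pattern the
// --- kGpuExternalAlloc/kGpuExternalFree/kGpuExternalEmptyCache value parsers above use.
#include <charconv>
#include <cstdint>
#include <stdexcept>
#include <string_view>
#include <system_error>

namespace sketch {

inline void* PointerFromOptionValue(std::string_view value) {
  std::uintptr_t address = 0;
  auto [ptr, ec] = std::from_chars(value.data(), value.data() + value.size(), address);
  if (ec != std::errc() || ptr != value.data() + value.size()) {
    throw std::invalid_argument("option value is not a valid pointer address");
  }
  return reinterpret_cast<void*>(address);  // round-trips what the caller serialized
}

}  // namespace sketch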
struct MIGraphXExecutionProviderInfo { - std::string target_device; + std::string target_device{"gpu"}; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool bf16_enable{false}; bool fp8_enable{false}; bool int8_enable{false}; - std::string int8_calibration_table_name{""}; + std::string int8_calibration_table_name{}; bool int8_use_native_calibration_table{false}; - bool save_compiled_model{true}; - std::string save_model_file{"./compiled_model.mxr"}; - bool load_compiled_model{true}; - std::string load_model_file{"./compiled_model.mxr"}; + std::filesystem::path model_cache_dir{}; bool exhaustive_tune{false}; - size_t mem_limit{std::numeric_limits::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) - ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) + size_t mem_limit{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; OrtArenaCfg* default_memory_arena_cfg{nullptr}; - MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info{}; - static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); - static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); - static ProviderOptions ToProviderOptions(const OrtMIGraphXProviderOptions& info); + void* external_alloc{nullptr}; + void* external_free{nullptr}; + void* external_empty_cache{nullptr}; + + bool UseExternalAlloc() const { + return external_alloc != nullptr && external_free != nullptr; + } + + MIGraphXExecutionProviderInfo() = default; + + explicit MIGraphXExecutionProviderInfo(const ProviderOptions& options); + explicit MIGraphXExecutionProviderInfo(const OrtMIGraphXProviderOptions& options) noexcept; + ProviderOptions ToProviderOptions() const; }; + } // namespace onnxruntime template <> struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { - size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const { + size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const noexcept { size_t value{0xbc9f1d34}; // seed // Bits: device_id (16), arena_extend_strategy (reserved 2), boolean options (1 each) @@ -75,17 +84,21 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { (static_cast(info.fp16_enable) << 18) ^ (static_cast(info.int8_enable) << 19) ^ (static_cast(info.int8_use_native_calibration_table) << 20) ^ - (static_cast(info.save_compiled_model) << 21) ^ - (static_cast(info.load_compiled_model) << 22) ^ - (static_cast(info.exhaustive_tune) << 23); + (static_cast(info.exhaustive_tune) << 21) ^ + (static_cast(info.bf16_enable) << 22); + onnxruntime::HashCombine(data, value); + onnxruntime::HashCombine(info.target_device, value); + onnxruntime::HashCombine(info.default_memory_arena_cfg, value); + onnxruntime::HashCombine(info.int8_calibration_table_name, value); + onnxruntime::HashCombine(info.model_cache_dir, value); onnxruntime::HashCombine(info.mem_limit, value); // Memory pointers - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_alloc), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_free), value); + 
onnxruntime::HashCombine(reinterpret_cast(info.external_empty_cache), value); // The default memory arena cfg is not used in hashing right now. return value; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 9274b5696185c..cce90f3ef82be 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -3,23 +3,27 @@ #pragma once +#include +#include +#include #include -#include -#include #include -#include #include +#include +#include +#include #include "flatbuffers/idl.h" #include "core/providers/migraphx/ort_trt_int8_cal_table.fbs.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/execution_provider.h" #include "core/common/path_string.h" +#include "core/framework/murmurhash3.h" namespace fs = std::filesystem; namespace onnxruntime { -bool IsGraphInput(const GraphViewer& graph, const std::string& name) { +inline bool IsGraphInput(const GraphViewer& graph, const std::string& name) { const auto& graph_inputs = graph.GetInputs(); std::vector input_names(graph_inputs.size()); std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) { @@ -28,12 +32,12 @@ bool IsGraphInput(const GraphViewer& graph, const std::string& name) { return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); } -bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { +inline bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; return graph.GetInitializedTensor(name, initializer); } -const Node* GetInputNode(const Node& node, int arg_index) { +inline const Node* GetInputNode(const Node& node, int arg_index) { int index = 0; for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index) { if (index == arg_index) { @@ -44,7 +48,7 @@ const Node* GetInputNode(const Node& node, int arg_index) { return nullptr; } -std::size_t getNodeInputNum(const Node& node) { +inline std::size_t getNodeInputNum(const Node& node) { std::size_t node_num = 0; for (auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) { node_num++; @@ -53,14 +57,14 @@ std::size_t getNodeInputNum(const Node& node) { return node_num; } -bool isInputNode(const Node* node, const std::string& name) { +inline bool isInputNode(const Node* node, const std::string& name) { auto outputs = node->OutputDefs(); return std::any_of(outputs.begin(), outputs.end(), [&](auto out) { return (out->Name() == name); }); } -bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { +inline bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { if (node == nullptr) { return false; } @@ -113,10 +117,10 @@ bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector return true; } -bool canEvalNodeArgument(const GraphViewer& graph, - const Node* node, - std::vector indices, - std::vector& input_nodes) { +inline bool canEvalNodeArgument(const GraphViewer& graph, + const Node* node, + std::vector indices, + std::vector& input_nodes) { input_nodes.clear(); std::vector in_nodes; for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) { @@ -152,7 
+156,7 @@ bool canEvalNodeArgument(const GraphViewer& graph, return true; } -float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { +inline float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { int s = (input >> 31) & 0x01; int e = ((input & 0x7f800000) >> 23) - 127; int p = -1; @@ -184,12 +188,12 @@ float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { * Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models * */ -bool ReadDynamicRange(const std::string file_name, - const bool is_calibration_table, - std::unordered_map& dynamic_range_map) { - std::ifstream infile(file_name, std::ios::binary | std::ios::in); - if (!infile) { +inline bool ReadDynamicRange(const std::filesystem::path& filename, + const bool is_calibration_table, + std::unordered_map& dynamic_range_map) { + std::ifstream infile{filename, std::ios::binary | std::ios::in}; + if (!infile.good()) { return false; } @@ -215,7 +219,7 @@ bool ReadDynamicRange(const std::string file_name, dynamic_range_map[tensor_name] = dynamic_range; } } else { - throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name); + throw std::runtime_error("This is not a TensorRT generated calibration table " + filename.string()); } } } else { @@ -240,14 +244,111 @@ bool ReadDynamicRange(const std::string file_name, * Get cache by name * */ -std::string GetCachePath(const std::string& root, const std::string& name) { - if (root.empty()) { - return name; +inline std::filesystem::path GetCachePath(const std::filesystem::path& root, std::string_view name) { + return root.empty() ? std::filesystem::path{ToPathString(name)} : root / ToPathString(name); +} + +inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { + HashValue model_hash; + + // find the top level graph + const Graph* cur_graph = &graph_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + + const Graph& main_graph = *cur_graph; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // Use the model's file name instead of the entire path to avoid cache regeneration if a path changes + const fs::path path{main_graph.ModelPath()}; + + if (path.has_filename()) { + const auto model_name = path.filename().string(); + + LOGS_DEFAULT(INFO) << "Model name is '" << model_name << "'"; + // Ensure enough characters are hashed in case model names are too short + const size_t model_name_length = model_name.length(); + constexpr size_t hash_string_length = 500; + std::string repeat_model_name = model_name; + for (size_t i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) { + repeat_model_name += model_name; + } + hash_str(repeat_model_name); } else { - fs::path path = root; - path.append(name); - return path.string(); + LOGS_DEFAULT(INFO) << "Model path is empty"; + } + + // fingerprint current graph by hashing graph inputs + for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) { + hash_str(node_arg->Name()); + } + + // hashing outputs, inputs and inputs shapes of each node + const int number_of_ort_nodes = graph_viewer.NumberOfNodes(); + std::vector nodes_vector(number_of_ort_nodes); + std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto& index : nodes_vector) { + const auto& node = 
graph_viewer.GetNode(node_index[index]); + for (const auto* node_arg : node->OutputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + } + } + for (const auto* node_arg : node->InputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + if (node_arg->Shape() == nullptr) { + continue; + } + int dim_size = node_arg->Shape()->dim_size(); + for (int i = 0; i < dim_size; i++) { + hash_str(std::to_string(node_arg->Shape()->dim(i).dim_value())); + } + } + } + } + +#ifdef __linux__ + hash_str("LINUX"); +#elif defined(_WIN32) + hash_str("WINDOWS"); +#endif + + model_hash = hash[0] | static_cast(hash[1]) << 32; + + std::array s{}; + auto [ptr, ec] = std::to_chars(s.data(), s.data() + s.size(), model_hash, 16); + return std::string{s.data(), ptr}; +} + +inline std::string_view TrimLeft(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(0, sv.end() - std::find_if(sv.begin(), sv.end(), [fn](int ch) { + return fn(ch); + })); +} + +inline std::string_view TrimRight(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(sv.end() - std::find_if(sv.rbegin(), sv.rend(), [fn](int ch) { + return fn(ch); + }).base()); +} + +inline std::string_view Trim(std::string_view sv, int (*fn)(int) = std::isspace) { + return TrimRight(TrimLeft(sv, fn), fn); +} + +inline int ToInteger(const std::string_view sv) { + int result = 0; + if (auto [_, ec] = std::from_chars(sv.data(), sv.data() + sv.length(), result); ec == std::errc()) { + return result; } + ORT_THROW("invalid input for conversion to integer"); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_inc.h b/onnxruntime/core/providers/migraphx/migraphx_inc.h index 2b035b20f619f..49e838747892f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_inc.h +++ b/onnxruntime/core/providers/migraphx/migraphx_inc.h @@ -6,3 +6,4 @@ #include #include #include +#include diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 923a39a7d2903..626758bce36d7 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -1,28 +1,36 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License + #include +#include +#include +#include +#include + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#include +#endif #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_provider_factory_creator.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" #include "core/framework/provider_options.h" #include "core/session/onnxruntime_c_api.h" -using namespace onnxruntime; - namespace onnxruntime { void InitializeRegistry(); void DeleteRegistry(); struct MIGraphXProviderFactory : IExecutionProviderFactory { - MIGraphXProviderFactory(const MIGraphXExecutionProviderInfo& info) : info_{info} {} - ~MIGraphXProviderFactory() override {} + explicit MIGraphXProviderFactory(MIGraphXExecutionProviderInfo info) : info_{std::move(info)} {} + ~MIGraphXProviderFactory() override = default; std::unique_ptr CreateProvider() override; @@ -35,11 +43,11 @@ std::unique_ptr MIGraphXProviderFactory::CreateProvider() { } struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { - std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } - std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } @@ -61,14 +69,39 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } - std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { - return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); + std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, + void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) override { + if (alloc_fn != nullptr && free_fn != nullptr) { + AllocatorCreationInfo default_memory_info{ + [alloc_fn, free_fn, empty_cache_fn](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP, alloc_fn, free_fn, empty_cache_fn); + }, + device_id, false}; + + return CreateAllocator(default_memory_info); + } + AllocatorCreationInfo default_memory_info{ + [](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP); + }, + device_id, + true, + {default_memory_arena_cfg ? 
*default_memory_arena_cfg + : OrtArenaCfg(mem_limit, static_cast(arena_extend_strategy), + -1, -1, -1, -1L)}, + // make it stream aware + true}; + + // ROCM malloc/free is expensive so always use an arena + return CreateAllocator(default_memory_info); } } g_info; -struct MIGraphX_Provider : Provider { +struct MIGraphX_Provider final : Provider { void* GetInfo() override { return &g_info; } + virtual ~MIGraphX_Provider() = default; + std::shared_ptr CreateExecutionProviderFactory(int device_id) override { MIGraphXExecutionProviderInfo info; info.device_id = static_cast(device_id); @@ -76,72 +109,49 @@ struct MIGraphX_Provider : Provider { return std::make_shared(info); } + // Method uses ProviderOptions, and not OrtMIGraphXProviderOptions (obsolete) std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); - MIGraphXExecutionProviderInfo info; - info.device_id = static_cast(options.device_id); - info.target_device = "gpu"; - info.fp16_enable = options.migraphx_fp16_enable; - info.fp8_enable = options.migraphx_fp8_enable; - info.exhaustive_tune = options.migraphx_exhaustive_tune; - info.int8_enable = options.migraphx_int8_enable; - info.int8_calibration_table_name = ""; - if (options.migraphx_int8_calibration_table_name != nullptr) { - info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name; - } - info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0; - info.save_compiled_model = options.migraphx_save_compiled_model; - info.save_model_file = ""; - if (options.migraphx_save_model_path != nullptr) { - info.save_model_file = options.migraphx_save_model_path; + if (provider_options != nullptr) { + return std::make_shared( + MIGraphXExecutionProviderInfo{*static_cast(provider_options)}); } - info.load_compiled_model = options.migraphx_load_compiled_model; - info.load_model_file = ""; - if (options.migraphx_load_model_path != nullptr) { - info.load_model_file = options.migraphx_load_model_path; - } - info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy); - info.mem_limit = options.migraphx_mem_limit; - return std::make_shared(info); + return nullptr; } void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options); - auto& migx_options = *reinterpret_cast(provider_options); - migx_options.device_id = internal_options.device_id; - migx_options.migraphx_fp16_enable = internal_options.fp16_enable; - migx_options.migraphx_fp8_enable = internal_options.fp8_enable; - migx_options.migraphx_int8_enable = internal_options.int8_enable; - migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; - - char* dest = nullptr; - auto str_size = internal_options.int8_calibration_table_name.size(); - if (str_size == 0) { - migx_options.migraphx_int8_calibration_table_name = nullptr; + MIGraphXExecutionProviderInfo internal_options{options}; + const auto migx_options = static_cast(provider_options); + migx_options->device_id = internal_options.device_id; + migx_options->migraphx_fp16_enable = internal_options.fp16_enable; + migx_options->migraphx_fp8_enable = internal_options.fp8_enable; + migx_options->migraphx_int8_enable = internal_options.int8_enable; + migx_options->migraphx_exhaustive_tune = internal_options.exhaustive_tune; + + if (internal_options.int8_calibration_table_name.empty()) { + 
migx_options->migraphx_int8_calibration_table_name = nullptr; } else { - dest = new char[str_size + 1]; + auto str_size = internal_options.int8_calibration_table_name.size(); + auto dest = new char[str_size + 1]; #ifdef _MSC_VER strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size); #else strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); #endif dest[str_size] = '\0'; - migx_options.migraphx_int8_calibration_table_name = (const char*)dest; + migx_options->migraphx_int8_calibration_table_name = static_cast(dest); } - migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; + migx_options->migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; - migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model; - migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str(); - migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model; - migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str(); - migx_options.migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); - migx_options.migraphx_mem_limit = internal_options.mem_limit; + migx_options->migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); + migx_options->migraphx_mem_limit = internal_options.mem_limit; } ProviderOptions GetProviderOptions(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); - return onnxruntime::MIGraphXExecutionProviderInfo::ToProviderOptions(options); + return provider_options != nullptr ? MIGraphXExecutionProviderInfo{ + *static_cast(provider_options)} + .ToProviderOptions() + : ProviderOptions{}; } Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, @@ -152,19 +162,29 @@ struct MIGraphX_Provider : Provider { const OrtLogger& logger, std::unique_ptr& ep) override { ORT_UNUSED_PARAMETER(num_devices); - const ConfigOptions* config_options = &session_options.GetConfigOptions(); - - std::array configs_array = {&provider_options, config_options}; - OrtMIGraphXProviderOptions migraphx_options; - UpdateProviderOptions(&migraphx_options, provider_options); - - auto ep_factory = CreateExecutionProviderFactory(&migraphx_options); + const auto ep_factory = CreateExecutionProviderFactory(&provider_options); ep = ep_factory->CreateProvider(session_options, logger); - return Status::OK(); } void Initialize() override { +#ifdef _WIN32 + HMODULE module = nullptr; + if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + static_cast(static_cast(InitializeRegistry)), + &module) != 0) { + std::vector pathBuf; + for (;;) { + pathBuf.resize(pathBuf.size() + MAX_PATH); + if (const auto writen = GetModuleFileNameW(module, pathBuf.data(), static_cast(pathBuf.size())); writen < pathBuf.size()) { + break; + } + } + std::filesystem::path path(pathBuf.begin(), pathBuf.end()); + SetDllDirectoryW(path.parent_path().native().c_str()); + } +#endif InitializeRegistry(); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index d1c9457bafa0f..c23c9947c8d9b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -1,22 +1,23 @@ -// Copyright 2019 AMD 
AMDMIGraphX +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License -#include "core/framework/provider_options.h" -#include "onnxruntime_c_api.h" +#pragma once + +#include + +#include "core/framework/arena_extend_strategy.h" +#include "core/framework/ortdevice.h" namespace onnxruntime { class IAllocator; -class IDataTransfer; -struct IExecutionProviderFactory; -struct MIGraphXExecutionProviderInfo; -enum class ArenaExtendStrategy : int32_t; -struct MIGraphXExecutionProviderExternalAllocatorInfo; struct ProviderInfo_MIGraphX { - virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0; virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0; - virtual std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, + ArenaExtendStrategy arena_extend_strategy, void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) = 0; protected: ~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h index 02d30ad0f6fbb..db169b9e2f5a9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h @@ -6,6 +6,7 @@ #include #include "core/providers/providers.h" +#include "core/framework/provider_options.h" struct OrtMIGraphXProviderOptions; @@ -14,5 +15,6 @@ namespace onnxruntime { struct MIGraphXProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(const OrtMIGraphXProviderOptions* options); + static std::shared_ptr Create(const ProviderOptions&); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index 6e492327a73a3..0baa8a1c67c67 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -1,17 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
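// --- Illustrative sketch with stand-in types (not ORT's AllocatorCreationInfo): the
// --- factory logic above prefers user-supplied alloc/free callbacks and only falls back
// --- to an arena-backed device allocator when none are given; everything in namespace
// --- sketch is an assumption made to keep the example self-contained.
#include <cstddef>
#include <functional>
#include <memory>

namespace sketch {

using AllocFn = void* (*)(std::size_t);
using FreeFn = void (*)(void*);

struct DeviceAllocator {
  std::function<void*(std::size_t)> alloc;
  std::function<void(void*)> free;
};

inline std::shared_ptr<DeviceAllocator> MakeAllocator(AllocFn user_alloc, FreeFn user_free) {
  if (user_alloc != nullptr && user_free != nullptr) {
    // external allocator path: no arena, allocations go straight to the user callbacks
    return std::make_shared<DeviceAllocator>(DeviceAllocator{user_alloc, user_free});
  }
  // default path: the provider wraps device allocation in an arena because raw
  // ROCm malloc/free is expensive; plain new/delete stands in for hipMalloc/hipFree here
  return std::make_shared<DeviceAllocator>(DeviceAllocator{
      [](std::size_t n) { return ::operator new(n); },
      [](void* p) { ::operator delete(p); }});
}

}  // namespace sketch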
-#include -#include "migraphx_stream_handle.h" +#include +#include + +#include "core/providers/resource.h" +#include "core/providers/migraphx/migraphx_stream_handle.h" + +#define MIGRAPHX_RESOURCE_VERSION 1 namespace onnxruntime { -struct MIGraphXNotification : public synchronize::Notification { - MIGraphXNotification(Stream& s) : Notification(s) { +enum MIGraphXResource { + hip_stream_t = rocm_resource_offset +}; + +struct MIGraphXNotification : synchronize::Notification { + explicit MIGraphXNotification(Stream& s) : Notification(s) { HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming)); } - ~MIGraphXNotification() { + ~MIGraphXNotification() override { if (event_) HIP_CALL_THROW(hipEventDestroy(event_)); } @@ -21,19 +30,19 @@ struct MIGraphXNotification : public synchronize::Notification { HIP_CALL_THROW(hipEventRecord(event_, static_cast(GetStream().GetHandle()))); } - void wait_on_device(Stream& device_stream) { - ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", - device_stream.GetDevice().ToString()); - // launch a wait command to the migraphx stream - HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), event_, 0)); - }; + void wait_on_device(Stream* device_stream) const { + if (device_stream != nullptr) { + ORT_ENFORCE(device_stream->GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream->GetDevice().ToString()); + // launch a wait command to the migraphx stream + HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream->GetHandle()), event_, 0)); + } + } - void wait_on_host() { - // CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); + void wait_on_host() const { HIP_CALL_THROW(hipEventSynchronize(event_)); } - hipEvent_t event_; + hipEvent_t event_{}; }; MIGraphXStream::MIGraphXStream(hipStream_t stream, @@ -41,15 +50,14 @@ MIGraphXStream::MIGraphXStream(hipStream_t stream, AllocatorPtr cpu_allocator, bool release_cpu_buffer_on_migraphx_stream) : Stream(stream, device), - cpu_allocator_(cpu_allocator), + cpu_allocator_(std::move(cpu_allocator)), release_cpu_buffer_on_migraphx_stream_(release_cpu_buffer_on_migraphx_stream) { } MIGraphXStream::~MIGraphXStream() { - ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); + ORT_IGNORE_RETURN_VALUE(MIGraphXStream::CleanUpOnRunEnd()); if (own_stream_) { - auto* handle = GetHandle(); - if (handle) + if (auto* handle = GetHandle()) HIP_CALL_THROW(hipStreamDestroy(static_cast(handle))); } } @@ -87,12 +95,12 @@ struct CpuBuffersInfo { std::unique_ptr buffers; // CPU buffer buffers[i]. // Number of buffer points in "buffers". 
- size_t n_buffers; + size_t n_buffers{}; }; static void ReleaseCpuBufferCallback(void* raw_info) { std::unique_ptr info = std::make_unique(); - info.reset(reinterpret_cast(raw_info)); + info.reset(static_cast(raw_info)); for (size_t i = 0; i < info->n_buffers; ++i) { info->allocator->Free(info->buffers[i]); } @@ -124,29 +132,25 @@ Status MIGraphXStream::CleanUpOnRunEnd() { } void* MIGraphXStream::GetResource(int version, int id) const { - ORT_ENFORCE(version <= ORT_ROCM_RESOURCE_VERSION, "resource version unsupported!"); - void* resource{}; - switch (id) { - case RocmResource::hip_stream_t: - return reinterpret_cast(GetHandle()); - default: - break; + ORT_ENFORCE(version <= MIGRAPHX_RESOURCE_VERSION, "resource version unsupported!"); + if (id == hip_stream_t) { + return GetHandle(); } - return resource; + return nullptr; } // CPU Stream command handles void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_device(*stream); + dynamic_cast(¬ification)->wait_on_device(stream); } void WaitMIGraphXNotificationOnHost(Stream* /*stream*/, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_host(); + dynamic_cast(¬ification)->wait_on_host(); } void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, const OrtDevice::DeviceType device_type, - AllocatorPtr cpu_allocator, + const AllocatorPtr& cpu_allocator, bool release_cpu_buffer_on_migraphx_stream, hipStream_t external_stream, bool use_existing_stream) { @@ -154,19 +158,20 @@ void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_r stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitMIGraphXNotificationOnDevice); // wait migraphx notification on cpu ep stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitMIGraphXNotificationOnHost); - if (!use_existing_stream) + if (!use_existing_stream) { stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream](const OrtDevice& device) { HIP_CALL_THROW(hipSetDevice(device.Id())); hipStream_t stream = nullptr; HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); }); - else + } else { stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream, external_stream](const OrtDevice& device) { return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); }); + } } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h index 886103690c661..132ae5fc09d13 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -2,12 +2,15 @@ // Licensed under the MIT License. 
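// --- Illustrative sketch of the notification pattern used in the stream handle above,
// --- assuming a ROCm toolchain (hip/hip_runtime.h): the producer records an event on its
// --- stream, a consumer stream waits on the device without blocking the host, and a CPU
// --- consumer blocks until the event has fired. Error codes are ignored for brevity.
#include <hip/hip_runtime.h>

namespace sketch {

struct Notification {
  hipEvent_t event{};

  Notification() { (void)hipEventCreateWithFlags(&event, hipEventDisableTiming); }
  ~Notification() { (void)hipEventDestroy(event); }

  // producer side: mark "work enqueued so far is done" on the producing stream
  void Activate(hipStream_t producer) { (void)hipEventRecord(event, producer); }

  // consumer on another GPU stream: enqueue a wait, the host returns immediately
  void WaitOnDevice(hipStream_t consumer) { (void)hipStreamWaitEvent(consumer, event, 0); }

  // consumer on the CPU: block the calling thread until the event has fired
  void WaitOnHost() { (void)hipEventSynchronize(event); }
};

}  // namespace sketch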
 #pragma once
+
+#include
+#include
+
 #include "core/framework/stream_handles.h"
-#include "migraphx_inc.h"
-#include "migraphx_call.h"
+#include "core/providers/migraphx/migraphx_inc.h"
+#include "core/providers/migraphx/migraphx_call.h"
 
 namespace onnxruntime {
 
-void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification);
 
 struct MIGraphXStream : Stream {
   MIGraphXStream(hipStream_t stream,
@@ -15,7 +18,7 @@ struct MIGraphXStream : Stream {
                  AllocatorPtr cpu_allocator,
                  bool release_cpu_buffer_on_migraphx_stream);
 
-  ~MIGraphXStream();
+  ~MIGraphXStream() override;
 
   std::unique_ptr<synchronize::Notification> CreateNotification(size_t /*num_consumers*/) override;
@@ -27,7 +30,7 @@ struct MIGraphXStream : Stream {
 
   bool own_stream_{true};
 
-  virtual void* GetResource(int version, int id) const;
+  void* GetResource(int version, int id) const override;
 
  private:
  std::vector<void*> deferred_cpu_buffers_;
@@ -36,8 +39,8 @@ struct MIGraphXStream : Stream {
 };
 
 void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
-                                   const OrtDevice::DeviceType device_type,
-                                   AllocatorPtr cpu_allocator,
+                                   OrtDevice::DeviceType device_type,
+                                   const AllocatorPtr& cpu_allocator,
                                    bool release_cpu_buffer_on_migraphx_stream,
                                    hipStream_t external_stream,
                                    bool use_existing_stream);
diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h
index 71d51c4c2992d..a7fd83f10fe18 100644
--- a/onnxruntime/core/providers/shared_library/provider_api.h
+++ b/onnxruntime/core/providers/shared_library/provider_api.h
@@ -326,6 +326,12 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe
                                         const logging::Logger& logger);
 
 std::string GetEnvironmentVar(const std::string& var_name);
+inline std::string GetEnvironmentVar(std::string_view var_name) {
+  return GetEnvironmentVar(std::string{var_name});
+}
+inline std::string GetEnvironmentVar(const char* var_name) {
+  return GetEnvironmentVar(std::string{var_name});
+}
 
 namespace profiling {
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
index 031a4df59d83f..d690cf31072d2 100644
--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
+++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
@@ -790,12 +790,12 @@ Status LoadDynamicLibrary(onnxruntime::PathString library_name) {
 #endif
 
 #ifdef _WIN32
-std::string ToUTF8String(const std::wstring& s) {
-  return g_host->ToUTF8String(s);
+std::string ToUTF8String(std::wstring_view s) {
+  return g_host->ToUTF8String(std::wstring{s});
 }
 
-std::wstring ToWideString(const std::string& s) {
-  return g_host->ToWideString(s);
+std::wstring ToWideString(std::string_view s) {
+  return g_host->ToWideString(std::string{s});
 }
 #endif  // _WIN32
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc
index 20ae1cfbfa2c1..c6bf29dafa184 100644
--- a/onnxruntime/core/providers/vitisai/imp/graph.cc
+++ b/onnxruntime/core/providers/vitisai/imp/graph.cc
@@ -15,7 +15,7 @@
 #include "vaip/node.h"
 #include "vaip/node_arg.h"
-
+#include "./tensor_proto.h"
 namespace vaip {
 
 struct NodeEdgeT {
@@ -286,7 +286,14 @@ Model* model_clone(const Model& original_model, int64_t external_data_threshold)
       cloned_tensor->add_dims(dim);
       size = size * dim;
     }
-    if (size >= external_data_threshold) {
+    auto ORT_MEM_ADDR_tag = 
process_ext_address(*original_tensor); + if (!ORT_MEM_ADDR_tag.empty()) { + cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto external_data = cloned_tensor->mutable_external_data(); + auto p = external_data->Add(); + *p->mutable_key() = "location"; + *p->mutable_value() = std::string("<") + graph_ptr; + } else if (size >= external_data_threshold) { cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); auto external_data = cloned_tensor->mutable_external_data(); auto p = external_data->Add(); diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index bb942c69003a1..2f1478bf1326b 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -10,7 +10,7 @@ namespace vaip { using namespace onnxruntime; -static gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { +gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { auto tensor_proto = const_cast(&tensor); auto file = std::string(); uintptr_t offset = 0; diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h index 73015d3411a54..a7c90ac18b44e 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h @@ -37,4 +37,5 @@ ONNX_NAMESPACE::TensorProto* tensor_proto_new_fp16(const std::string& name, cons const std::vector& data); ONNX_NAMESPACE::TensorProto* tensor_proto_new_doubles(const std::string& name, const std::vector& shape, const std::vector& data); +gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor); } // namespace vaip diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4c7b4d7b29c2f..88d84e95b406c 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1378,9 +1378,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerTypeInfo, _In_ cons return GetNodeDefTypeInfoHelper(sess, get_overridable_initializers_fn, index, out); } -char* onnxruntime::StrDup(const std::string& str, OrtAllocator* allocator) { - char* output_string = reinterpret_cast(allocator->Alloc(allocator, str.size() + 1)); - memcpy(output_string, str.c_str(), str.size()); +char* onnxruntime::StrDup(std::string_view str, OrtAllocator* allocator) { + char* output_string = static_cast(allocator->Alloc(allocator, str.size() + 1)); + memcpy(output_string, str.data(), str.size()); + output_string[str.size()] = '\0'; + return output_string; +} + +wchar_t* onnxruntime::StrDup(std::wstring_view str, OrtAllocator* allocator) { + auto* output_string = static_cast(allocator->Alloc(allocator, (str.size() + 1) * sizeof(wchar_t))); + memcpy(output_string, str.data(), str.size() * sizeof(wchar_t)); output_string[str.size()] = '\0'; return output_string; } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 01b70db6d940e..ee59ff2ab4932 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2108,8 +2108,13 @@ std::shared_ptr NvProviderFactoryCreator::Create( return nullptr; } -std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { - return 
s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options); +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const ProviderOptions& provider_options) { + return s_library_migraphx.Get().CreateExecutionProviderFactory(&provider_options); +} + +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* options) { + const auto provider_options{s_library_migraphx.Get().GetProviderOptions(options)}; + return s_library_migraphx.Get().CreateExecutionProviderFactory(&provider_options); } // Adapter to convert the legacy OrtOpenVINOProviderOptions to ProviderOptions diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 18a463ef69943..48d52ae3cf428 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -101,6 +101,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, VitisAI, CoreML, NvTensorRtRtx, // TensorRt EP for RTX GPUs. + MIGraphX }; struct EpToAppend { @@ -109,7 +110,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, const char* canonical_name = nullptr; }; - static std::array supported_eps = { + static std::array supported_eps = { EpToAppend{EpID::DML, "DML", kDmlExecutionProvider}, EpToAppend{EpID::QNN, "QNN", kQnnExecutionProvider}, EpToAppend{EpID::OpenVINO, "OpenVINO", kOpenVINOExecutionProvider}, @@ -121,7 +122,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, EpToAppend{EpID::JS, "JS", kJsExecutionProvider}, EpToAppend{EpID::VitisAI, "VitisAI", kVitisAIExecutionProvider}, EpToAppend{EpID::CoreML, "CoreML", kCoreMLExecutionProvider}, - EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}}; + EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}, + EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}}; ProviderOptions provider_options; OrtStatus* status = ParseProviderOptions(provider_options_keys, @@ -279,6 +281,14 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); +#endif + break; + } + case EpID::MIGraphX: { +#if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE) + options->provider_factories.push_back(MIGraphXProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); #endif break; } diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 590e1ef3cdbdb..1934e0eda7956 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -196,9 +196,9 @@ void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_CUDA().cudaMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetCudaToHostMemCpyFunction() { +const std::unordered_map* GetCudaToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0}, CudaToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device.Id()}, CudaToCpuMemCpy}, }; return ↦ @@ -246,6 +246,7 @@ std::unique_ptr GetGPUDataTransfer() { #endif #ifdef USE_MIGRAPHX + void CpuToMIGraphXMemCpy(void* dst, const void* 
src, size_t num_bytes) { GetProviderInfo_MIGraphX().MIGraphXMemcpy_HostToDevice(dst, src, num_bytes); } @@ -256,7 +257,7 @@ void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { const std::unordered_map* GetMIGraphXToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, MIGraphXToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, MIGraphXToCpuMemCpy}, }; return ↦ @@ -270,7 +271,11 @@ AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id) { if (id_to_allocator_map->find(id) == id_to_allocator_map->end()) { // TODO: Expose knobs so that users can set fields associated with OrtArenaCfg so that we can pass it to the following method - id_to_allocator_map->insert({id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator(id, gpu_mem_limit, arena_extend_strategy, migx_external_allocator_info, nullptr)}); + id_to_allocator_map->insert( + {id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator( + id, gpu_mem_limit, arena_extend_strategy, + migraphx::external::alloc_fn, migraphx::external::free_fn, migraphx::external::empty_cache_fn, + nullptr)}); } return (*id_to_allocator_map)[id]; @@ -374,9 +379,9 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { D3D12_RESOURCE_STATE_UNORDERED_ACCESS); } -const std::unordered_map* GetDmlToHostMemCpyFunction() { +const std::unordered_map* GetDmlToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, 0}, DmlToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device.Id()}, DmlToCpuMemCpy}, }; return ↦ @@ -444,9 +449,9 @@ void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_ROCM().rocmMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetRocmToHostMemCpyFunction() { +const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice& device) { static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, RocmToCpuMemCpy}, + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, RocmToCpuMemCpy}, }; return ↦ diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h index 7b65c0aae45c1..eba783d826212 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h @@ -74,7 +74,7 @@ void CpuToCudaMemCpy(void* dst, const void* src, size_t num_bytes); void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetCudaToHostMemCpyFunction(); +const std::unordered_map* GetCudaToHostMemCpyFunction(const OrtDevice&); bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id); @@ -92,7 +92,7 @@ void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes); void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetDmlToHostMemCpyFunction(); +const std::unordered_map* GetDmlToHostMemCpyFunction(const OrtDevice&); #endif @@ -102,7 +102,7 @@ void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes); void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetMIGraphXToHostMemCpyFunction(); +const 
std::unordered_map* GetMIGraphXToHostMemCpyFunction(const OrtDevice&); AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id); @@ -132,7 +132,7 @@ void CpuToRocmMemCpy(void* dst, const void* src, size_t num_bytes); void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetRocmToHostMemCpyFunction(); +const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice&); #endif diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 7234543eb14de..1fe7ab0884f9c 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -421,21 +421,39 @@ void addOrtValueMethods(pybind11::module& m) { // Converts Tensor into a numpy array .def("numpy", [](const OrtValue* ml_value) -> py::object { ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects"); - + [[maybe_unused]] const auto& device = ml_value->Get().Location().device; +#ifdef _MSC_VER +// The switch statement may only contain the 'default' label. In such a case, the MSVC compiler +// will warn about it, and since the warnings are treated as errors, the compilation will break. +// Below pragmas turn off warning generation for this switch only. +#pragma warning(push) +#pragma warning(disable : 4065) +#endif + switch (device.Vendor()) { #ifdef USE_CUDA - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction()); -#elif USE_ROCM - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction()); -#elif USE_CANN - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); -#elif USE_DML - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction()); -#elif USE_MIGRAPHX - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction()); -#else - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr); + case OrtDevice::VendorIds::NVIDIA: + return GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction(device)); +#endif +#ifdef USE_CANN + case OrtDevice::VendorIds::HUAWEI: + return GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); #endif - return obj; }) + +#ifdef USE_DML + case OrtDevice::VendorIds::MICROSOFT: + return GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction(device)); +#endif +#ifdef USE_MIGRAPHX + case OrtDevice::VendorIds::AMD: + return GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction(device)); +#endif + default: + return GetPyObjFromTensor(*ml_value, nullptr, nullptr); + } +#ifdef _MSC_VER +#pragma warning(pop) +#endif + }) #if defined(ENABLE_DLPACK) .def("to_dlpack", [](OrtValue* ort_value) -> py::object { return py::reinterpret_steal(ToDlpack(*ort_value)); }, "Returns a DLPack representing the tensor. 
This method does not copy the pointer shape, " diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 03ad0185d1394..24554560b4dde 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -979,135 +979,10 @@ static std::shared_ptr CreateExecutionProviderFactory #endif } else if (type == kMIGraphXExecutionProvider) { #if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE) - std::string calibration_table; - std::string save_model_path; - std::string load_model_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { - OrtMIGraphXProviderOptions params{ - 0, - 0, - 0, - 0, - 0, - nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", - 1, - SIZE_MAX, - 0}; - for (auto option : it->second) { - if (option.first == "device_id") { - if (!option.second.empty()) { - params.device_id = std::stoi(option.second); - } else { - ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n"); - } - } else if (option.first == "migraphx_fp16_enable") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_fp8_enable") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp8_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp8_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_int8_enable") { - if (option.second == "True" || option.second == "true") { - params.migraphx_int8_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_int8_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_int8_calibration_table_name") { - if (!option.second.empty()) { - calibration_table = option.second; - params.migraphx_int8_calibration_table_name = calibration_table.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a " - "file name i.e. 'cal_table'.\n"); - } - } else if (option.first == "migraphx_use_native_calibration_table") { - if (option.second == "True" || option.second == "true") { - params.migraphx_use_native_calibration_table = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_use_native_calibration_table = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be" - " 'True' or 'False'. 
Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_save_compiled_model") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_save_model_path") { - if (!option.second.empty()) { - save_model_path = option.second; - params.migraphx_save_model_path = save_model_path.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_name' should be a " - "file name i.e. 'compiled_model.mxr'.\n"); - } - } else if (option.first == "migraphx_load_compiled_model") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_load_model_path") { - if (!option.second.empty()) { - load_model_path = option.second; - params.migraphx_load_model_path = load_model_path.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a " - "file name i.e. 'compiled_model.mxr'.\n"); - } - } else if (option.first == "migraphx_exhaustive_tune") { - if (option.second == "True" || option.second == "true") { - params.migraphx_exhaustive_tune = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_exhaustive_tune = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be" - " 'True' or 'False'. 
Default value is 'False'.\n"); - } - } else { - ORT_THROW("Invalid MIGraphX EP option: ", option.first); - } - } if (std::shared_ptr migraphx_provider_factory = - onnxruntime::MIGraphXProviderFactoryCreator::Create(¶ms)) { + onnxruntime::MIGraphXProviderFactoryCreator::Create(it->second)) { return migraphx_provider_factory; } } else { @@ -1917,7 +1792,6 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra vendor = OrtDevice::VendorIds::HUAWEI; #endif } - return OrtDevice(type, mem_type, vendor, device_id); }), R"pbdoc(Constructor with vendor_id defaulted to 0 for backward compatibility.)pbdoc") diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc index 4b9e012764885..cccdb9d23900a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.cc +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -47,7 +47,11 @@ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExten #endif #ifdef USE_MIGRAPHX -onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info{}; +namespace migraphx::external { +void* alloc_fn{nullptr}; +void* free_fn{nullptr}; +void* empty_cache_fn{nullptr}; +} // namespace migraphx::external #endif #if defined(ENABLE_DLPACK) diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 706c151936192..b4a33e798f942 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -40,7 +40,7 @@ struct OrtStatus { #define BACKEND_PROC "CPU" #endif -#if USE_DNNL +#ifdef USE_DNNL #define BACKEND_DNNL "-DNNL" #else #define BACKEND_DNNL "" @@ -226,9 +226,14 @@ extern onnxruntime::ArenaExtendStrategy arena_extend_strategy; namespace onnxruntime { ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX(); ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX(); -namespace python { -extern onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info; -} // namespace python +namespace python::migraphx::external { +extern void* alloc_fn; +extern void* free_fn; +extern void* empty_cache_fn; +inline bool UseExternalAllocator() { + return alloc_fn != nullptr && free_fn != nullptr; +} +} // namespace python::migraphx::external } // namespace onnxruntime #endif diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 2e4aa3923b649..bae7a14908916 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -80,21 +80,7 @@ std::unique_ptr TensorrtExecutionProviderWithOptions(const O std::unique_ptr DefaultMIGraphXExecutionProvider() { #ifdef USE_MIGRAPHX - OrtMIGraphXProviderOptions params{ - 0, - 0, - 0, - 0, - 0, - nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", - 1, - SIZE_MAX, - 0}; - return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider(); + return MIGraphXProviderFactoryCreator::Create(ProviderOptions{})->CreateProvider(); #else return nullptr; #endif @@ -102,7 +88,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { std::unique_ptr MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params) { #ifdef USE_MIGRAPHX - if (auto factory = MIGraphXProviderFactoryCreator::Create(params)) + if (const auto factory = MIGraphXProviderFactoryCreator::Create(params); factory != nullptr) return factory->CreateProvider(); #else 
ORT_UNUSED_PARAMETER(params); diff --git a/setup.py b/setup.py index 5ab1ac5b840d4..6bfb53329f319 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ def parse_arg_remove_string(argv, arg_name_equal): is_cuda_version_12 = cuda_version.startswith("12.") elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"): is_migraphx = True + package_name = "onnxruntime-migraphx" elif parse_arg_remove_boolean(sys.argv, "--use_openvino"): is_openvino = True package_name = "onnxruntime-openvino" @@ -90,8 +91,6 @@ def parse_arg_remove_string(argv, arg_name_equal): is_qnn = True package_name = "onnxruntime-qnn" qnn_version = parse_arg_remove_string(sys.argv, "--qnn_version=") -elif is_migraphx: - package_name = "onnxruntime-migraphx" if not nightly_build else "ort-migraphx-nightly" # PEP 513 defined manylinux1_x86_64 and manylinux1_i686 # PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686 @@ -283,7 +282,6 @@ def run(self): self._rewrite_ld_preload_tensorrt(to_preload_tensorrt) self._rewrite_ld_preload_tensorrt(to_preload_nv_tensorrt_rtx) self._rewrite_ld_preload(to_preload_cann) - else: pass @@ -412,6 +410,7 @@ def finalize_options(self): libs.extend(["onnxruntime_providers_nv_tensorrt_rtx.dll"]) libs.extend(["onnxruntime_providers_openvino.dll"]) libs.extend(["onnxruntime_providers_cuda.dll"]) + libs.extend(["onnxruntime_providers_migraphx.dll"]) libs.extend(["onnxruntime_providers_vitisai.dll"]) libs.extend(["onnxruntime_providers_qnn.dll"]) # DirectML Libs @@ -435,6 +434,26 @@ def finalize_options(self): libs.extend(qnn_deps) if nightly_build: libs.extend(["onnxruntime_pywrapper.dll"]) + migraphx_deps = [ + "amd_comgr0602.dll", + "amd_comgr0604.dll", + "amd_comgr0700.dll", + "hiprtc0602.dll", + "hiprtc0604.dll", + "hiprtc0700.dll", + "hiprtc-builtins0602.dll", + "hiprtc-builtins0604.dll", + "hiprtc-builtins0700.dll", + "migraphx-hiprtc-driver.exe", + "migraphx.dll", + "migraphx_c.dll", + "migraphx_cpu.dll", + "migraphx_device.dll", + "migraphx_gpu.dll", + "migraphx_onnx.dll", + "migraphx_tf.dll", + ] + libs.extend(migraphx_deps) if is_manylinux: if is_openvino: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 561a76be5fa89..0d51f66df33aa 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -723,8 +723,6 @@ def generate_build_tree( cmake_args += ["-Donnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD=ON"] if args.use_migraphx: cmake_args.append("-Donnxruntime_MIGRAPHX_HOME=" + migraphx_home) - cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) - cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) if args.use_tensorrt: cmake_args.append("-Donnxruntime_TENSORRT_HOME=" + tensorrt_home) @@ -1994,6 +1992,7 @@ def build_nuget_package( use_winml, use_qnn, use_dml, + use_migraphx, enable_training_apis, msbuild_extra_options, ): @@ -2031,6 +2030,9 @@ def build_nuget_package( elif use_tensorrt: execution_provider = "/p:ExecutionProvider=tensorrt" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" + elif use_migraphx: + execution_provider = "/p:ExecutionProvider=migraphx" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.MIGraphX" elif use_dnnl: execution_provider = "/p:ExecutionProvider=dnnl" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DNNL" @@ -2622,6 +2624,7 @@ def main(): getattr(args, "use_winml", False), args.use_qnn, getattr(args, "use_dml", False), + args.use_migraphx, args.enable_training_apis, normalize_arg_list(args.msbuild_extra_options), ) diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py 
b/tools/nuget/generate_nuspec_for_native_nuget.py index 211cb7a2a8a75..ead240a7cef1b 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -22,6 +22,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "linux": pkg_name += "-linux-" pkg_name += cpu_arch @@ -31,6 +33,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "osx": pkg_name = "onnxruntime-osx-" + cpu_arch return pkg_name @@ -44,7 +48,11 @@ def get_package_name(os, cpu_arch, ep, is_training_package): def is_this_file_needed(ep, filename, package_name): if package_name == "Microsoft.ML.OnnxRuntime.Gpu": return False - return (ep != "cuda" or "cuda" in filename) and (ep != "tensorrt" or "cuda" not in filename) + return ( + (ep != "cuda" or "cuda" in filename) + and (ep != "tensorrt" or "cuda" not in filename) + and (ep != "migraphx" or "migraphx" not in filename) + ) # nuget_artifacts_dir: the directory with uncompressed C API tarball/zip files @@ -138,7 +146,7 @@ def parse_arguments(): required=False, default="None", type=str, - choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"], + choices=["cuda", "dnnl", "openvino", "migraphx", "tensorrt", "snpe", "qnn", "None"], help="The selected execution provider for this build.", ) parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.") @@ -182,6 +190,8 @@ def generate_description(line_list, package_name): description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA." elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name: description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA." + elif "Microsoft.ML.OnnxRuntime.MIGraphX" in package_name: + description = "This package contains native shared library artifacts for ONNX Runtime with MIGraphX." elif "Intel.ML.OnnxRuntime" in package_name: description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO." 
elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package @@ -359,6 +369,7 @@ def generate_files(line_list, args): is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning" is_snpe_package = args.package_name == "Microsoft.ML.OnnxRuntime.Snpe" is_qnn_package = args.package_name == "Microsoft.ML.OnnxRuntime.QNN" + is_migraphx_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX" is_training_package = args.package_name in [ "Microsoft.ML.OnnxRuntime.Training", "Microsoft.ML.OnnxRuntime.Training.Gpu", @@ -384,6 +395,24 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", "qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll", + "migraphx_ep_shared_lib": "onnxruntime_providers_migraphx.dll", + "amd_comgr0602": "amd_comgr0602.dll", + "amd_comgr0604": "amd_comgr0604.dll", + "amd_comgr0700": "amd_comgr0700.dll", + "hiprtc0602": "hiprtc0602.dll", + "hiprtc0604": "hiprtc0604.dll", + "hiprtc0700": "hiprtc0700.dll", + "hiprtc-builtins0602": "hiprtc-builtins0602.dll", + "hiprtc-builtins0604": "hiprtc-builtins0604.dll", + "hiprtc-builtins0700": "hiprtc-builtins0700.dll", + "migraphx-hiprtc-driver": "migraphx-hiprtc-driver.exe", + "migraphx": "migraphx.dll", + "migraphx_c": "migraphx_c.dll", + "migraphx_cpu": "migraphx_cpu.dll", + "migraphx_device": "migraphx_device.dll", + "migraphx_gpu": "migraphx_gpu.dll", + "migraphx_onnx": "migraphx_onnx.dll", + "migraphx_tf": "migraphx_tf", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -402,6 +431,7 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "libonnxruntime_providers_openvino.so", "cuda_ep_shared_lib": "libonnxruntime_providers_cuda.so", "rocm_ep_shared_lib": "libonnxruntime_providers_rocm.so", + "migraphx_ep_shared_lib": "libonnxruntime_providers_migraphx.so", "onnxruntime_perf_test": "onnxruntime_perf_test", "onnx_test_runner": "onnx_test_runner", } @@ -421,7 +451,7 @@ def generate_files(line_list, args): include_dir = f"{build_dir}\\native\\include" # Sub.Gpu packages do not include the onnxruntime headers - if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu": + if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu" and args.package_name != "Microsoft.ML.OnnxRuntime.MIGraphX": files_list.append( "' ) + if args.execution_provider == "migraphx": + files_list.append( + "' + ) + files_list.append( + "' + ) + + if is_windows_build: + native_build_path = Path(args.native_build_path) + + def _files_list_append(key: str): + path = native_build_path / nuget_dependencies[key] + if path.exists(): + files_list.append( + "' + ) + + _files_list_append("amd_comgr0602") + _files_list_append("amd_comgr0604") + _files_list_append("amd_comgr0700") + _files_list_append("hiprtc0602") + _files_list_append("hiprtc0604") + _files_list_append("hiprtc0700") + _files_list_append("hiprtc-builtins0602") + _files_list_append("hiprtc-builtins0604") + _files_list_append("hiprtc-builtins0700") + _files_list_append("migraphx-hiprtc-driver") + _files_list_append("migraphx") + _files_list_append("migraphx_c") + _files_list_append("migraphx_cpu") + _files_list_append("migraphx_device") + _files_list_append("migraphx_gpu") + _files_list_append("migraphx_onnx") + _files_list_append("migraphx_tf") + if is_dml_package: files_list.append( "