Skip to content

Commit 8056966

Browse files
apwojcikurpetkov-amdTedThemistokleousTed Themistokleousskottmckay
authored andcommitted
[MIGraphX EP] Syncing AMD changes upstream (microsoft#25583)
A set of changes required for WCR/WindowsML that were added to the MIGraphX Execution provider. The development was done in the ROCm repository, now we want to sync with the main branch with a single drop. The PR incorporates the review comments from the previous closed PR microsoft#25338. Motivation and Context Fixes, changes, and updates to MIGraphX EP that have been done for ROCm development. Pushing this back upstream to ensure mainline onnxruntime is using the latest changes. Moving forward, MIGraphX EP will be cut from the latest official release tag as a base point while also adding additional features that will be contributed back. --------- Co-authored-by: urpetkov-amd <[email protected]> Co-authored-by: Ted Themistokleous <[email protected]> Co-authored-by: Ted Themistokleous <[email protected]> Co-authored-by: Scott McKay <[email protected]>
1 parent 6ec439c commit 8056966

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1108
-859
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# build, distribute, and bins (+ python proto bindings)
2+
build.*/
23
build
34
build_*/
45
.build_debug/*

cmake/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ include(CheckLanguage)
2929
include(CMakeDependentOption)
3030
include(FetchContent)
3131
include(CheckFunctionExists)
32+
include(CheckSymbolExists)
3233
include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables
3334

3435
# TODO: update this once all system adapt c++20

cmake/onnxruntime_providers_migraphx.cmake

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,11 @@
22
# Licensed under the MIT License.
33

44
add_definitions(-DUSE_MIGRAPHX=1)
5-
set(BUILD_LIBRARY_ONLY 1)
6-
add_definitions("-DONNX_ML=1")
7-
add_definitions("-DONNX_NAMESPACE=onnx")
8-
include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR})
9-
set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME})
10-
include_directories(${onnx_SOURCE_DIR})
5+
include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR} ${onnx_SOURCE_DIR})
116
set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
12-
if ( CMAKE_COMPILER_IS_GNUCC )
7+
if (CMAKE_COMPILER_IS_GNUCC)
138
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers")
149
endif()
15-
set(CXX_VERSION_DEFINED TRUE)
16-
set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS})
17-
if ( CMAKE_COMPILER_IS_GNUCC )
18-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
19-
endif()
2010

2111
# Add search paths for default rocm installation
2212
list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH})
@@ -33,23 +23,21 @@
3323
find_package(hip REQUIRED)
3424
find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME})
3525

36-
set(migraphx_libs migraphx::c hip::host)
37-
3826
file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS
3927
"${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h"
4028
"${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc"
4129
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h"
4230
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
4331
)
4432
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs})
45-
onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs})
33+
onnxruntime_add_shared_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs})
4634
onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
47-
add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
48-
target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
49-
target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime)
35+
add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES})
36+
target_link_libraries(onnxruntime_providers_migraphx PRIVATE migraphx::c hip::host ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
37+
target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/migraphx/onnxruntime)
5038
set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX)
5139
set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime")
52-
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1)
40+
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1 ONNX_ML=1 ONNX_NAMESPACE=onnx)
5341
if(MSVC)
5442
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def)
5543
target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32)
@@ -62,6 +50,15 @@
6250
target_link_libraries(onnxruntime_providers_migraphx PRIVATE stdc++fs)
6351
endif()
6452

53+
set(CMAKE_REQUIRED_LIBRARIES migraphx::c)
54+
55+
check_symbol_exists(migraphx_onnx_options_set_external_data_path
56+
"migraphx/migraphx.h" HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH)
57+
58+
if(HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH)
59+
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH=1)
60+
endif()
61+
6562
if (onnxruntime_ENABLE_TRAINING_OPS)
6663
onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_training)
6764
target_link_libraries(onnxruntime_providers_migraphx PRIVATE onnxruntime_training)
@@ -71,15 +68,39 @@
7168
endif()
7269

7370
if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
74-
install(TARGETS onnxruntime_providers_migraphx
75-
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
76-
LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR}
77-
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
78-
)
79-
else()
80-
install(TARGETS onnxruntime_providers_migraphx
81-
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
82-
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
83-
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
84-
)
71+
foreach(file migraphx-hiprtc-driver.exe migraphx.dll migraphx_c.dll migraphx_cpu.dll migraphx_device.dll migraphx_gpu.dll migraphx_onnx.dll migraphx_tf.dll)
72+
set(_source "${AMD_MIGRAPHX_HOME}/bin/${file}")
73+
if(EXISTS "${_source}")
74+
add_custom_command(TARGET onnxruntime_providers_migraphx
75+
POST_BUILD
76+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $<TARGET_FILE_DIR:onnxruntime_providers_migraphx>)
77+
set(_target "$<TARGET_FILE_DIR:onnxruntime_providers_migraphx>/${file}")
78+
list(APPEND _migraphx_targets ${_target})
79+
endif()
80+
endforeach()
81+
set(MIGRAPHX_LIB_FILES ${_migraphx_targets} CACHE INTERNAL "" FORCE)
82+
install(FILES ${_migraphx_targets}
83+
DESTINATION ${CMAKE_INSTALL_BINDIR})
84+
get_property(_amdhip64_location TARGET hip::amdhip64 PROPERTY IMPORTED_LOCATION_RELEASE)
85+
cmake_path(GET _amdhip64_location PARENT_PATH _hipsdk_path)
86+
foreach(file amd_comgr0602.dll amd_comgr0604.dll amd_comgr0700.dll hiprtc0602.dll hiprtc0604.dll hiprtc0700.dll hiprtc-builtins0602.dll hiprtc-builtins0604.dll hiprtc-builtins0700.dll)
87+
set(_source "${_hipsdk_path}/${file}")
88+
if(EXISTS "${_source}")
89+
add_custom_command(TARGET onnxruntime_providers_migraphx
90+
POST_BUILD
91+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $<TARGET_FILE_DIR:onnxruntime_providers_migraphx>)
92+
set(_target "$<TARGET_FILE_DIR:onnxruntime_providers_migraphx>/${file}")
93+
list(APPEND _hipsdk_targets ${_target})
94+
endif()
95+
endforeach()
96+
set(HIPSDK_LIB_FILES ${_hipsdk_targets} CACHE INTERNAL "" FORCE)
97+
install(FILES ${_hipsdk_targets}
98+
DESTINATION ${CMAKE_INSTALL_BINDIR})
8599
endif()
100+
101+
install(TARGETS onnxruntime_providers_migraphx
102+
EXPORT onnxruntime_providers_migraphxTargets
103+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
104+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
105+
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
106+
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})

cmake/onnxruntime_python.cmake

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,21 @@ if (onnxruntime_USE_OPENVINO)
740740
)
741741
endif()
742742

743+
if (onnxruntime_USE_MIGRAPHX)
744+
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
745+
add_custom_command(
746+
TARGET onnxruntime_pybind11_state POST_BUILD
747+
COMMAND ${CMAKE_COMMAND} -E copy
748+
${MIGRAPHX_LIB_FILES}
749+
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/)
750+
add_custom_command(
751+
TARGET onnxruntime_pybind11_state POST_BUILD
752+
COMMAND ${CMAKE_COMMAND} -E copy
753+
${HIPSDK_LIB_FILES}
754+
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/)
755+
endif()
756+
endif()
757+
743758
if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS)
744759
add_custom_command(
745760
TARGET onnxruntime_pybind11_state POST_BUILD

cmake/onnxruntime_unittests.cmake

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,6 @@ endif()
610610

611611
if(onnxruntime_USE_MIGRAPHX)
612612
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx)
613-
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared)
614613
endif()
615614

616615
if(onnxruntime_USE_COREML)
@@ -691,9 +690,6 @@ endif()
691690

692691
if(onnxruntime_USE_MIGRAPHX)
693692
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*)
694-
list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h")
695-
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx)
696-
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared)
697693
endif()
698694

699695
if(onnxruntime_USE_NNAPI_BUILTIN)

include/onnxruntime/core/common/common.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,12 +294,26 @@ inline std::string ToUTF8String(const std::string& s) { return s; }
294294
/**
295295
* Convert a wide character string to a UTF-8 string
296296
*/
297-
std::string ToUTF8String(const std::wstring& s);
298-
299-
std::wstring ToWideString(const std::string& s);
297+
std::string ToUTF8String(std::wstring_view s);
298+
inline std::string ToUTF8String(const wchar_t* s) {
299+
return ToUTF8String(std::wstring_view{s});
300+
}
301+
inline std::string ToUTF8String(const std::wstring& s) {
302+
return ToUTF8String(std::wstring_view{s});
303+
}
304+
std::wstring ToWideString(std::string_view s);
305+
inline std::wstring ToWideString(const char* s) {
306+
return ToWideString(std::string_view{s});
307+
}
308+
inline std::wstring ToWideString(const std::string& s) {
309+
return ToWideString(std::string_view{s});
310+
}
300311
inline std::wstring ToWideString(const std::wstring& s) { return s; }
312+
inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; }
301313
#else
302314
inline std::string ToWideString(const std::string& s) { return s; }
315+
inline std::string ToWideString(const char* s) { return s; }
316+
inline std::string ToWideString(std::string_view s) { return std::string{s}; }
303317
#endif
304318

305319
constexpr size_t kMaxStrLen = 4096;

include/onnxruntime/core/common/string_helper.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,9 @@
77
// forward declaration
88
struct OrtAllocator;
99
namespace onnxruntime {
10-
char* StrDup(const std::string& str, OrtAllocator* allocator);
10+
char* StrDup(std::string_view str, OrtAllocator* allocator);
11+
inline char* StrDup(const std::string& str, OrtAllocator* allocator) {
12+
return StrDup(std::string_view{str}, allocator);
13+
}
14+
wchar_t* StrDup(std::wstring_view str, OrtAllocator* allocator);
1115
} // namespace onnxruntime

include/onnxruntime/core/framework/provider_options_utils.h

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,24 @@ class ProviderOptionsParser {
8383
template <typename ValueParserType>
8484
ProviderOptionsParser& AddValueParser(
8585
const std::string& name, ValueParserType value_parser) {
86+
return AddValueParser(std::string_view{name}, value_parser);
87+
}
88+
89+
template <typename ValueParserType>
90+
ProviderOptionsParser& AddValueParser(
91+
std::string_view name, ValueParserType value_parser) {
8692
ORT_ENFORCE(
8793
value_parsers_.emplace(name, ValueParser{value_parser}).second,
8894
"Provider option \"", name, "\" already has a value parser.");
8995
return *this;
9096
}
9197

98+
template <typename ValueParserType>
99+
ProviderOptionsParser& AddValueParser(
100+
const char* name, ValueParserType value_parser) {
101+
return AddValueParser<ValueParserType>(std::string_view{name}, value_parser);
102+
}
103+
92104
/**
93105
* Adds a parser for a particular provider option value which converts a
94106
* value to the right type and assigns it to the given reference.
@@ -104,13 +116,25 @@ class ProviderOptionsParser {
104116
template <typename ValueType>
105117
ProviderOptionsParser& AddAssignmentToReference(
106118
const std::string& name, ValueType& dest) {
119+
return AddAssignmentToReference(std::string_view{name}, dest);
120+
}
121+
122+
template <typename ValueType>
123+
ProviderOptionsParser& AddAssignmentToReference(
124+
std::string_view name, ValueType& dest) {
107125
return AddValueParser(
108126
name,
109-
[&dest](const std::string& value_str) -> Status {
127+
[&dest](std::string_view value_str) -> Status {
110128
return ParseStringWithClassicLocale(value_str, dest);
111129
});
112130
}
113131

132+
template <typename ValueType>
133+
ProviderOptionsParser& AddAssignmentToReference(
134+
const char* name, ValueType& dest) {
135+
return AddAssignmentToReference<ValueType>(std::string_view{name}, dest);
136+
}
137+
114138
/**
115139
* Adds a parser for a particular provider option value which maps an
116140
* enumeration name to a value and assigns it to the given reference.
@@ -128,13 +152,25 @@ class ProviderOptionsParser {
128152
template <typename EnumType>
129153
ProviderOptionsParser& AddAssignmentToEnumReference(
130154
const std::string& name, const EnumNameMapping<EnumType>& mapping, EnumType& dest) {
155+
return AddAssignmentToEnumReference(std::string_view{name}, mapping, dest);
156+
}
157+
158+
template <typename EnumType>
159+
ProviderOptionsParser& AddAssignmentToEnumReference(
160+
std::string_view name, const EnumNameMapping<EnumType>& mapping, EnumType& dest) {
131161
return AddValueParser(
132162
name,
133163
[&mapping, &dest](const std::string& value_str) -> Status {
134164
return NameToEnum(mapping, value_str, dest);
135165
});
136166
}
137167

168+
template <typename EnumType>
169+
ProviderOptionsParser& AddAssignmentToEnumReference(
170+
const char* name, const EnumNameMapping<EnumType>& mapping, EnumType& dest) {
171+
return AddAssignmentToEnumReference<EnumType>(std::string_view{name}, mapping, dest);
172+
}
173+
138174
/**
139175
* Parses the given provider options.
140176
*/

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -754,13 +754,13 @@ typedef struct OrtMIGraphXProviderOptions {
754754
int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
755755
int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true
756756
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
757-
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
757+
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, nonzero = true
758758
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
759-
int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true
759+
int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, nonzero = true
760760
const char* migraphx_save_model_path; // migraphx model path name
761-
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
761+
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, nonzero = true
762762
const char* migraphx_load_model_path; // migraphx model path name
763-
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false
763+
bool migraphx_exhaustive_tune; // MIGraphX tuned compile. Default = false, nonzero = true
764764

765765
/** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t)
766766
* Defaults to SIZE_MAX.
@@ -776,6 +776,7 @@ typedef struct OrtMIGraphXProviderOptions {
776776
*/
777777
int migraphx_arena_extend_strategy;
778778

779+
// This is the legacy struct and don't add new fields here.
779780
} OrtMIGraphXProviderOptions;
780781

781782
/** \brief OpenVINO Provider Options

onnxruntime/core/common/helper.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
namespace onnxruntime {
2020
#ifdef _WIN32
21-
std::string ToUTF8String(const std::wstring& s) {
21+
std::string ToUTF8String(std::wstring_view s) {
2222
if (s.size() >= static_cast<size_t>(std::numeric_limits<int>::max()))
2323
ORT_THROW("length overflow");
2424

@@ -33,7 +33,7 @@ std::string ToUTF8String(const std::wstring& s) {
3333
return ret;
3434
}
3535

36-
std::wstring ToWideString(const std::string& s) {
36+
std::wstring ToWideString(std::string_view s) {
3737
if (s.size() >= static_cast<size_t>(std::numeric_limits<int>::max()))
3838
ORT_THROW("length overflow");
3939

0 commit comments

Comments
 (0)