Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cmake/onnxruntime_providers_migraphx.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH=1)
endif()

check_symbol_exists(migraphx_get_onnx_operators_size
"migraphx/migraphx.h" HAVE_MIGRAPHX_API_GET_ONNX_OPERATORS)

if(HAVE_MIGRAPHX_API_GET_ONNX_OPERATORS)
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE HAVE_MIGRAPHX_API_GET_ONNX_OPERATORS=1)
endif()



if (onnxruntime_ENABLE_TRAINING_OPS)
onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_training)
target_link_libraries(onnxruntime_providers_migraphx PRIVATE onnxruntime_training)
Expand Down
14 changes: 11 additions & 3 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,15 @@ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const st

if (target_op_type == "If" || target_op_type == "Loop" || target_op_type == "Scan") {
const auto& src_output_idx = it->GetSrcArgIndex();
if (src_output_idx < node->OutputDefs().size()) {

// Do this to avoid signed to unsigned comparrison here
// if src_output_index is invalid (-1 or less) signal that to be larger than size + 1
// This ensures the check below fails
size_t output_index = 0;
if(src_output_idx < 0)
output_index = node->OutputDefs().size() + 1;

if (output_index < node->OutputDefs().size()) {
const auto* output_def = node->OutputDefs()[src_output_idx];
if (output_def && fused_outputs.find(output_def) == fused_outputs.end() && erased.find(output_def) == erased.end()) {
fused_outputs_to_add[output_def] = output_order++;
Expand Down Expand Up @@ -866,9 +874,9 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
/*out*/ std::unordered_set<std::string>& mgx_required_initializers,
const logging::Logger& logger) {

#if HIP_VERSION_MAJOR > 7 || (HIP_VERSION_MAJOR == 7 && HIP_VERSION_MINOR >= 2)
#ifdef HAVE_MIGRAPHX_API_GET_ONNX_OPERATORS
// In ROCm 7.2 onward we'll query the MIGraphX API to get the supported op list
static std::set<std::string> mgx_supported_ops{}
static std::set<std::string> mgx_supported_ops{};
auto list = migraphx::get_onnx_operators();
for(const auto& name : list)
{
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/test/providers/migraphx/migraphx_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) {
ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true);
}

#if defined(WIN32)

static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
// Access the underlying InferenceSession.
const OrtSession* ort_session = session;
Expand All @@ -203,7 +205,6 @@ static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
return has_ep;
}

#if defined(WIN32)
// Tests autoEP feature to automatically select an EP that supports the GPU.
// Currently only works on Windows.
TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) {
Expand Down
Loading