diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index a59347841be95..1fb2c2c5f1663 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -694,7 +694,7 @@ static bool IsNodeSupported(const std::set& op_set, return true; } -std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const { +std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph, bool is_graph_split) const { std::unordered_set node_set; node_set.reserve(graph_nodes_index.size()); for (const auto& index : graph_nodes_index) { @@ -808,7 +808,13 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st if (output.second->Exists()) { auto name = output.second->Name(); if (std::find(graph_output_names.begin(), graph_output_names.end(), name) == graph_output_names.end()) { + // if graph is split we dont know if output is used so we need this, otherwise if the graph isn't split + // then we can safely assume this output is a dangling output from a node and to discard it as part of the + // final graph output + if(is_graph_split) + { output_names.push_back(name); + } } else { graph_out_names.insert(name); } @@ -1085,11 +1091,12 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v // Example weights, reshape shape etc. std::unordered_set mgx_required_initializers; const auto unsupported_nodes = GetUnsupportedNodeIndices(graph_viewer, mgx_required_initializers, *GetLogger()); + bool is_graph_not_split = unsupported_nodes.empty(); // If all ops are supported, no partitioning is required. Short-circuit and avoid splitting. - if (unsupported_nodes.empty()) { + if (is_graph_not_split) { auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); - auto sub_graph = GetSubGraph(node_indices, graph_viewer); + auto sub_graph = GetSubGraph(node_indices, graph_viewer, !is_graph_not_split); result.push_back(ComputeCapability::Create(std::move(sub_graph))); } else { // unsupported_nodes_idx.empty() if (dump_model_ops_) { @@ -1110,7 +1117,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v SubgraphPostProcessing(graph_viewer, mgx_clusters, *GetLogger()); for (const auto& this_cluster : mgx_clusters) { - auto sub_graph = GetSubGraph(this_cluster, graph_viewer); + auto sub_graph = GetSubGraph(this_cluster, graph_viewer, !is_graph_not_split); result.push_back(ComputeCapability::Create(std::move(sub_graph))); } } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 99f790b9f9f7a..ea0c2f7d9a060 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -82,7 +82,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; - std::unique_ptr GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const; + std::unique_ptr GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph, bool is_graph_split) const; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override;