15 changes: 11 additions & 4 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -694,7 +694,7 @@ static bool IsNodeSupported(const std::set<std::string>& op_set,
return true;
}

- std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph) const {
+ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph, bool is_graph_split) const {
std::unordered_set<size_t> node_set;
node_set.reserve(graph_nodes_index.size());
for (const auto& index : graph_nodes_index) {
@@ -808,7 +808,13 @@ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const st
if (output.second->Exists()) {
auto name = output.second->Name();
if (std::find(graph_output_names.begin(), graph_output_names.end(), name) == graph_output_names.end()) {
+ // if graph is split we dont know if output is used so we need this, otherwise if the graph isn't split
Contributor

If the graph is split, can't we check if an output is used by checking if it has an external consumer? Wouldn't it be better to build this check into the logic above so that we don't have to add another parameter into the function call and complicate the code?
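
A minimal sketch of the kind of check suggested here, written against hypothetical toy types (`ToyNode`, `ToyEdge`, `HasExternalConsumer`, `CollectSubgraphOutputs`) rather than the ONNX Runtime graph API. It assumes a node output should be exposed by the subgraph only if it is a model output or is read by a node outside the subgraph; the actual `GetSubGraph` logic may track this differently.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-ins for illustration only; not the ONNX Runtime types.
struct ToyEdge {
  std::string output_name;  // name of the value carried by this edge
  std::size_t consumer;     // index of the node that reads the value
};

struct ToyNode {
  std::size_t index;
  std::vector<std::string> outputs;
  std::vector<ToyEdge> output_edges;
};

// True if `name` is read by at least one node outside `node_set`.
bool HasExternalConsumer(const ToyNode& node, const std::string& name,
                         const std::unordered_set<std::size_t>& node_set) {
  for (const auto& edge : node.output_edges) {
    if (edge.output_name == name && node_set.count(edge.consumer) == 0) {
      return true;
    }
  }
  return false;
}

// Collects the values a subgraph must expose: model outputs plus values
// consumed outside the subgraph. Everything else is treated as dangling.
std::vector<std::string> CollectSubgraphOutputs(
    const std::vector<ToyNode>& nodes,
    const std::unordered_set<std::size_t>& node_set,
    const std::unordered_set<std::string>& model_outputs) {
  std::vector<std::string> result;
  for (const auto& node : nodes) {
    if (node_set.count(node.index) == 0) continue;  // node is not in this subgraph
    for (const auto& name : node.outputs) {
      if (model_outputs.count(name) != 0 ||
          HasExternalConsumer(node, name, node_set)) {
        result.push_back(name);
      }
    }
  }
  return result;
}
```

With this rule the same code path covers both the full-graph and the split case, at the cost of walking each node's output edges.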

Contributor Author

We typically want to run the entire model through MIGraphX rather than add memory copies to and from another EP (say, the CPU EP), since those add overhead.

Sure, the logic can be added later, but in the "no partition" case, why bother checking every node? We know the entire graph's inputs and outputs from the original metadata, so we can confidently say that any other node outputs are unused and can be pruned.

For the split case, you can check, but going back to the first point (we want to minimize fallback), a split would be a signal for us to make sure every operator in the graph is supported and to add the missing support in MIGraphX.

Contributor

@Jonahcb, Oct 3, 2025

That's true. This approach reduces the overhead of looping through the edges of every node.

One concern: this approach will promote internal dangling outputs to subgraph outputs in the "partition" case, which will then be computed and allocated when they should have been pruned in this step, leading to a waste of memory and computation later on.

Contributor Author

No, because the flag for the unsupported nodes will be nonzero, and then we revert back to the previous behavior as if the is_graph_split check wasn't set anymore.

The issue isn't just that we have dangling output edges from internal nodes, its that in the full case they're interpreted as a valid output to be promoted to the graph output, modifying the metadata. In this case too there are 0 fusions being done that we saw in the model.

In the split case, yes you need to check if that node output edge is consumed before determine if its dangling.

Contributor

@Jonahcb, Oct 3, 2025

Sorry, I don't understand what you mean.

> No, because the flag for the unsupported nodes will be nonzero, and then we revert back to the previous behavior as if the is_graph_split check wasn't set anymore.

By "previous behavior," you mean the behavior before this PR, which promoted dangling outputs from individual nodes to outputs of the full subgraph?

I assume that is what you mean because you say this:

> In the split case, yes you need to check if that node output edge is consumed before determine if its dangling.

But given that this PR does not check "if that node output edge is consumed before determin[ing] if it's dangling", the logic in the split case is incorrect.

Specifically, is_graph_split will evaluate to true and promote internal dangling outputs of nodes in this subgraph to subgraph outputs. Although these dangling outputs won't be promoted to model outputs, they will still be treated as outputs of the subgraph when they should have been pruned.
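
To make the concern concrete, here is a toy sketch of only the branch this diff touches. `SelectOutputs` and `SelectedOutputs` are hypothetical names, and the sketch assumes `candidates` already holds node outputs that no node inside the subgraph consumes; how the real `GetSubGraph` builds that list is not shown in this hunk.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical result type for illustration only.
struct SelectedOutputs {
  std::vector<std::string> subgraph_outputs;            // plays the role of output_names
  std::unordered_set<std::string> model_outputs_found;  // plays the role of graph_out_names
};

// Mirrors the flag-based rule under discussion: a candidate that is not a
// model output is kept only when the graph is split, with no consumer check.
SelectedOutputs SelectOutputs(const std::vector<std::string>& candidates,
                              const std::unordered_set<std::string>& model_outputs,
                              bool is_graph_split) {
  SelectedOutputs selected;
  for (const auto& name : candidates) {
    if (model_outputs.count(name) == 0) {
      if (is_graph_split) {
        selected.subgraph_outputs.push_back(name);
      }
    } else {
      selected.model_outputs_found.insert(name);
    }
  }
  return selected;
}
```

Under these assumptions, a dangling value such as `a_unused` is dropped when `is_graph_split` is false but is still listed as a subgraph output when it is true.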

> The issue isn't just that we have dangling output edges from internal nodes, its that in the full case they're interpreted as a valid output to be promoted to the graph output, modifying the metadata. In this case too there are 0 fusions being done that we saw in the model.

I agree that there are two issues:
1. Dangling outputs promoted to full model outputs
2. Dangling outputs promoted to fused node outputs

This solves 1, but keeps problem 2 intact.

A better approach would be to solve 1 and 2 with one modification.

Contributor Author

@TedThemistokleous, Oct 3, 2025

Right, but this isn't about "better"; it's about "good enough" for now. I've got a customer waiting on a change for a fully parsed model, so case 2 doesn't apply. If we hit case 2, we typically either let MIGraphX do the optimizations (we turn off all optimizations and let MIGraphX handle them) or add the missing parser op support. In fact, as with other EPs, the intent is that case 1 is the default and that support gets added so we never have to split the graph.

+ // then we can safely assume this output is a dangling output from a node and to discard it as part of the
+ // final graph output
+ if(is_graph_split)
+ {
output_names.push_back(name);
+ }
} else {
graph_out_names.insert(name);
}
@@ -1085,11 +1091,12 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
// Example weights, reshape shape etc.
std::unordered_set<std::string> mgx_required_initializers;
const auto unsupported_nodes = GetUnsupportedNodeIndices(graph_viewer, mgx_required_initializers, *GetLogger());
+ bool is_graph_not_split = unsupported_nodes.empty();

// If all ops are supported, no partitioning is required. Short-circuit and avoid splitting.
- if (unsupported_nodes.empty()) {
+ if (is_graph_not_split) {
auto node_indices = graph_viewer.GetNodesInTopologicalOrder();
- auto sub_graph = GetSubGraph(node_indices, graph_viewer);
+ auto sub_graph = GetSubGraph(node_indices, graph_viewer, !is_graph_not_split);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
} else { // unsupported_nodes_idx.empty()
if (dump_model_ops_) {
@@ -1110,7 +1117,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
SubgraphPostProcessing(graph_viewer, mgx_clusters, *GetLogger());

for (const auto& this_cluster : mgx_clusters) {
- auto sub_graph = GetSubGraph(this_cluster, graph_viewer);
+ auto sub_graph = GetSubGraph(this_cluster, graph_viewer, !is_graph_not_split);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
}
}
@@ -82,7 +82,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;

- std::unique_ptr<IndexedSubGraph> GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph) const;
+ std::unique_ptr<IndexedSubGraph> GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph, bool is_graph_split) const;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;