diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index a59347841be95..0ce7f5a7ce702 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -766,19 +766,37 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st } } } else { - for (const auto& output : node->OutputDefs()) { + const auto& output_defs = node->OutputDefs(); + for (size_t output_idx = 0, end = output_defs.size(); output_idx < end; ++output_idx) { + const auto* output = output_defs[output_idx]; const auto& it = fused_inputs.find(output); if (it != fused_inputs.end()) { fused_inputs.erase(it); erased.insert(output); } - // Only when output is neither in input list nor erased list, add the output to output list - else { - if (erased.find(output) == erased.end()) { - if (std::find(graph_output_names.begin(), - graph_output_names.end(), output->Name()) != graph_output_names.end()) { + // Only when output is neither in input list nor erased list, evaluate whether we should keep it. Only keep it if it has an external consumer, or it is a graph output. + else if (erased.find(output) == erased.end()) { + const bool is_graph_output = std::find(graph_output_names.begin(), + graph_output_names.end(), output->Name()) != graph_output_names.end(); + + bool has_external_consumer = false; + for (auto edge_it = node->OutputEdgesBegin(), edge_end = node->OutputEdgesEnd(); + edge_it != edge_end; ++edge_it) { + if (edge_it->GetSrcArgIndex() == static_cast(output_idx) && + node_set.find(edge_it->GetNode().Index()) == node_set.end()) { + has_external_consumer = true; + break; + } + } + + if (has_external_consumer) { + fused_outputs_to_add[output] = output_order; + if (is_graph_output) { graph_outputs_to_add[output] = output_order; } + ++output_order; + } else if (is_graph_output) { + graph_outputs_to_add[output] = output_order; fused_outputs[output] = output_order++; } } @@ -799,7 +817,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st outputs.insert(std::pair(it->second, it->first)); } - // It is possible that an output of an node is put bebind the output of an later + // It is possible that an output of an node is put behind the output of an later // node in the graph output list. So we should sort the output name according // to the graph output names std::vector output_names;