Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -766,19 +766,37 @@ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const st
}
}
} else {
for (const auto& output : node->OutputDefs()) {
const auto& output_defs = node->OutputDefs();
for (size_t output_idx = 0, end = output_defs.size(); output_idx < end; ++output_idx) {
const auto* output = output_defs[output_idx];
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
}
// Only when output is neither in input list nor erased list, add the output to output list
else {
if (erased.find(output) == erased.end()) {
if (std::find(graph_output_names.begin(),
graph_output_names.end(), output->Name()) != graph_output_names.end()) {
// Only when output is neither in input list nor erased list, evaluate whether we should keep it. Only keep it if it has an external consumer, or it is a graph output.
else if (erased.find(output) == erased.end()) {
const bool is_graph_output = std::find(graph_output_names.begin(),
graph_output_names.end(), output->Name()) != graph_output_names.end();

bool has_external_consumer = false;
for (auto edge_it = node->OutputEdgesBegin(), edge_end = node->OutputEdgesEnd();
edge_it != edge_end; ++edge_it) {
if (edge_it->GetSrcArgIndex() == static_cast<int>(output_idx) &&
node_set.find(edge_it->GetNode().Index()) == node_set.end()) {
has_external_consumer = true;
break;
}
}

if (has_external_consumer) {
fused_outputs_to_add[output] = output_order;
if (is_graph_output) {
graph_outputs_to_add[output] = output_order;
}
++output_order;
} else if (is_graph_output) {
graph_outputs_to_add[output] = output_order;
fused_outputs[output] = output_order++;
}
}
Expand All @@ -799,7 +817,7 @@ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const st
outputs.insert(std::pair<int, const NodeArg*>(it->second, it->first));
}

// It is possible that an output of an node is put bebind the output of an later
// It is possible that an output of an node is put behind the output of an later
// node in the graph output list. So we should sort the output name according
// to the graph output names
std::vector<std::string> output_names;
Expand Down