Skip to content

Commit 5dc57b7

Browse files
authored
Merge pull request #12593 from NHZlX/filter_redundant_output
filter redundant output
2 parents 5305282 + 943950c commit 5dc57b7

File tree

4 files changed

+33
-3
lines changed

4 files changed

+33
-3
lines changed

paddle/fluid/inference/analysis/data_flow_graph.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
337337
std::vector<Node *>(outputs.begin(), outputs.end()));
338338
}
339339

340+
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
341+
std::vector<Node *> op_nodes;
342+
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
343+
if (node.type() == Node::Type::kValue || node.deleted()) {
344+
continue;
345+
}
346+
op_nodes.push_back(&node);
347+
}
348+
size_t op_num = op_nodes.size();
349+
for (size_t i = 0; i < op_num; i++) {
350+
if (op_nodes[i]->type() == Node::Type::kFunction) continue;
351+
std::unordered_set<std::string> follow_up_input_names;
352+
for (size_t j = i + 1; j < op_num; j++) {
353+
for (auto *in : op_nodes[j]->inlinks) {
354+
follow_up_input_names.insert(in->name());
355+
}
356+
}
357+
std::vector<Node *> filtered_subgraph_outlinks;
358+
for (auto *out : op_nodes[i]->outlinks) {
359+
if (follow_up_input_names.count(out->name())) {
360+
filtered_subgraph_outlinks.push_back(out);
361+
}
362+
}
363+
PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
364+
op_nodes[i]->outlinks = filtered_subgraph_outlinks;
365+
}
366+
}
367+
340368
} // namespace analysis
341369
} // namespace inference
342370
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ struct GraphTraits<DataFlowGraph> {
178178
std::pair<std::vector<Node *>, std::vector<Node *>>
179179
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph); // NOLINT
180180

181+
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph);
181182
} // namespace analysis
182183
} // namespace inference
183184
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
5252
bool DataFlowGraphToFluidPass::Finalize() { return true; }
5353

5454
void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
55+
FilterRedundantOutputOfSubGraph(graph);
5556
LOG(INFO) << "graph.inputs " << graph->inputs.size();
5657
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
5758
if (node.deleted()) continue;

paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
4646
for (size_t i = 0; i < graph->nodes.size(); i++) {
4747
const Node &node = graph->nodes.Get(i);
4848
if (!config_.display_deleted_node && node.deleted()) continue;
49-
for (auto &in : node.inlinks) {
50-
if (!config_.display_deleted_node && in->deleted()) continue;
51-
dot.AddEdge(in->repr(), node.repr(), {});
49+
for (auto &out : node.outlinks) {
50+
if (!config_.display_deleted_node && out->deleted()) continue;
51+
dot.AddEdge(node.repr(), out->repr(), {});
5252
}
5353
}
5454
return dot.Build();

0 commit comments

Comments
 (0)