File tree Expand file tree Collapse file tree 4 files changed +33
-3
lines changed
paddle/fluid/inference/analysis Expand file tree Collapse file tree 4 files changed +33
-3
lines changed Original file line number Diff line number Diff line change @@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
337337 std::vector<Node *>(outputs.begin (), outputs.end ()));
338338}
339339
340+ void FilterRedundantOutputOfSubGraph (DataFlowGraph *graph) {
341+ std::vector<Node *> op_nodes;
342+ for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS ()) {
343+ if (node.type () == Node::Type::kValue || node.deleted ()) {
344+ continue ;
345+ }
346+ op_nodes.push_back (&node);
347+ }
348+ size_t op_num = op_nodes.size ();
349+ for (size_t i = 0 ; i < op_num; i++) {
350+ if (op_nodes[i]->type () == Node::Type::kFunction ) continue ;
351+ std::unordered_set<std::string> follow_up_input_names;
352+ for (size_t j = i + 1 ; j < op_num; j++) {
353+ for (auto *in : op_nodes[j]->inlinks ) {
354+ follow_up_input_names.insert (in->name ());
355+ }
356+ }
357+ std::vector<Node *> filtered_subgraph_outlinks;
358+ for (auto *out : op_nodes[i]->outlinks ) {
359+ if (follow_up_input_names.count (out->name ())) {
360+ filtered_subgraph_outlinks.push_back (out);
361+ }
362+ }
363+ PADDLE_ENFORCE_GE (filtered_subgraph_outlinks.size (), 1UL );
364+ op_nodes[i]->outlinks = filtered_subgraph_outlinks;
365+ }
366+ }
367+
340368} // namespace analysis
341369} // namespace inference
342370} // namespace paddle
Original file line number Diff line number Diff line change @@ -178,6 +178,7 @@ struct GraphTraits<DataFlowGraph> {
178178std::pair<std::vector<Node *>, std::vector<Node *>>
179179ExtractInputAndOutputOfSubGraph (std::vector<Node *> &graph); // NOLINT
180180
181+ void FilterRedundantOutputOfSubGraph (DataFlowGraph *graph);
181182} // namespace analysis
182183} // namespace inference
183184} // namespace paddle
Original file line number Diff line number Diff line change @@ -52,6 +52,7 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
5252bool DataFlowGraphToFluidPass::Finalize () { return true ; }
5353
5454void DataFlowGraphToFluidPass::Run (DataFlowGraph *graph) {
55+ FilterRedundantOutputOfSubGraph (graph);
5556 LOG (INFO) << " graph.inputs " << graph->inputs .size ();
5657 for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS ()) {
5758 if (node.deleted ()) continue ;
Original file line number Diff line number Diff line change @@ -46,9 +46,9 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
4646 for (size_t i = 0 ; i < graph->nodes .size (); i++) {
4747 const Node &node = graph->nodes .Get (i);
4848 if (!config_.display_deleted_node && node.deleted ()) continue ;
49- for (auto &in : node.inlinks ) {
50- if (!config_.display_deleted_node && in ->deleted ()) continue ;
51- dot.AddEdge (in-> repr (), node. repr (), {});
49+ for (auto &out : node.outlinks ) {
50+ if (!config_.display_deleted_node && out ->deleted ()) continue ;
51+ dot.AddEdge (node. repr (), out-> repr (), {});
5252 }
5353 }
5454 return dot.Build ();
You can’t perform that action at this time.
0 commit comments