Skip to content

Commit 5ec2fb0

Browse files
committed
add flexibledfs for find path between two nodes
1 parent af15f6f commit 5ec2fb0

File tree

3 files changed

+111
-0
lines changed

3 files changed

+111
-0
lines changed

paddle/fluid/inference/analysis/data_flow_graph.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,13 +480,50 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
480480
for (auto *out : op_nodes[i]->outlinks) {
481481
if (follow_up_input_names.count(out->name())) {
482482
filtered_subgraph_outlinks.push_back(out);
483+
} else {
484+
out->SetDeleted();
483485
}
484486
}
485487
PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
486488
op_nodes[i]->outlinks = filtered_subgraph_outlinks;
487489
}
488490
}
489491

492+
void FlexibleDFS(const std::vector<Node *> &source, bool reverse,
493+
const std::function<bool(const Node *)> &enter,
494+
const std::function<bool(const Node *)> &leave) {
495+
typedef struct {
496+
const Node *node;
497+
bool leave;
498+
} FNode;
499+
std::vector<FNode> stack;
500+
for (auto &node : source) {
501+
stack.push_back(FNode{node, false});
502+
}
503+
std::unordered_set<const Node *> visited;
504+
while (!stack.empty()) {
505+
auto fnode = stack.back();
506+
stack.pop_back();
507+
508+
if (fnode.leave) {
509+
if (leave && !leave(fnode.node)) return;
510+
}
511+
if (visited.count(fnode.node)) continue;
512+
visited.insert(fnode.node);
513+
514+
if (enter && !enter(fnode.node)) return;
515+
516+
if (leave) stack.push_back(FNode{fnode.node, true});
517+
const std::vector<Node *> iter_nodes =
518+
reverse == true ? fnode.node->inlinks : fnode.node->outlinks;
519+
for (const Node *node : iter_nodes) {
520+
if (!visited.count(node)) {
521+
stack.push_back(FNode{node, false});
522+
}
523+
}
524+
}
525+
}
526+
490527
} // namespace analysis
491528
} // namespace inference
492529
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ std::pair<std::vector<Node *>, std::vector<Node *>>
204204
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph); // NOLINT
205205

206206
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph);
207+
void FlexibleDFS(const std::vector<Node *> &source, bool reverse,
208+
const std::function<bool(const Node *)> &enter,
209+
const std::function<bool(const Node *)> &leave);
207210
} // namespace analysis
208211
} // namespace inference
209212
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph_tester.cc

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,77 @@ TEST(DataFlowGraph, Build_IR_Graph) {
160160
ASSERT_EQ(graph.nodes.size(), ir_graph.Nodes().size());
161161
}
162162

163+
// FlexibleDFS
164+
/*
165+
* Graph topology
166+
* inputs: 0
167+
* 0 -> 1
168+
* 1 -> 2
169+
* 1 -> 3
170+
* 3 -> 4
171+
* 4 -> 5
172+
* 5 -> 2
173+
*/
174+
TEST(DataFlowGraph, flexibledfs) {
175+
DataFlowGraph graph;
176+
177+
for (int i = 0; i < 6; i++) {
178+
auto* node = graph.nodes.Create(Node::Type::kValue);
179+
node->SetName("node-" + std::to_string(i));
180+
}
181+
182+
auto add_link = [&](int i, int j) {
183+
Node* source = graph.nodes.GetMutable(i);
184+
Node* target = graph.nodes.GetMutable(j);
185+
target->inlinks.push_back(source);
186+
source->outlinks.push_back(target);
187+
};
188+
189+
add_link(0, 1);
190+
add_link(1, 2);
191+
add_link(1, 3);
192+
add_link(3, 4);
193+
add_link(4, 5);
194+
add_link(5, 2);
195+
graph.Build();
196+
197+
std::vector<const Node*> order;
198+
FlexibleDFS(graph.inputs(), false, nullptr, [&order](const Node* n) {
199+
order.push_back(n);
200+
return true;
201+
});
202+
203+
ASSERT_EQ(order.size(), 6UL);
204+
205+
order.clear();
206+
// reverse dfs
207+
FlexibleDFS(graph.outputs(), true, nullptr, [&order](const Node* n) {
208+
order.push_back(n);
209+
return true;
210+
});
211+
212+
ASSERT_EQ(order.size(), 6UL);
213+
214+
// If we delete
215+
Node* last_node = graph.nodes.GetMutable(2);
216+
Node* direct_node = graph.nodes.GetMutable(1);
217+
std::vector<Node*> source_nodes;
218+
for (Node* node : last_node->inlinks) {
219+
if (node != direct_node) source_nodes.push_back(node);
220+
}
221+
222+
bool has_cycle = false;
223+
FlexibleDFS(source_nodes, true, nullptr,
224+
[&has_cycle, direct_node](const Node* n) {
225+
if (n == direct_node) {
226+
has_cycle = true;
227+
return false;
228+
}
229+
return true;
230+
});
231+
ASSERT_TRUE(has_cycle);
232+
}
233+
163234
} // namespace analysis
164235
} // namespace inference
165236
} // namespace paddle

0 commit comments

Comments
 (0)