Skip to content

Commit 0cefb94

Browse files
authored
add topological sortting (#12059)
1 parent f920244 commit 0cefb94

File tree

3 files changed

+188
-3
lines changed

3 files changed

+188
-3
lines changed

paddle/fluid/inference/analysis/data_flow_graph.cc

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,20 @@ std::string DataFlowGraph::DotString() const {
9090
return dot.Build();
9191
}
9292

93+
std::string DataFlowGraph::HumanReadableInfo(bool show_values,
94+
bool show_functions) const {
95+
std::stringstream values, functions;
96+
for (auto &n : nodes.nodes()) {
97+
if (show_values && n->IsValue()) {
98+
values << n->repr() << "\n";
99+
}
100+
if (show_functions && n->IsFunction()) {
101+
functions << n->repr() << "\n";
102+
}
103+
}
104+
return "Values:\n" + values.str() + "\n\n" + "Functions:\n" + functions.str();
105+
}
106+
93107
//
94108
// NodesBFSIterator
95109
//
@@ -146,7 +160,7 @@ bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==(
146160
if ((!queue_.empty()) && (!other.queue_.empty())) {
147161
return queue_.front() == other.queue_.front() &&
148162
visited_.size() == other.visited_.size(); // here need to check the
149-
// equality of queue and
163+
// equality of queue and
150164
// visited. Just a light but week implementation.
151165
}
152166
return false;
@@ -208,6 +222,76 @@ Node *GraphTraits<DataFlowGraph>::NodesDFSIterator::operator->() {
208222
return stack_.top();
209223
}
210224

225+
GraphTraits<DataFlowGraph>::NodesTSIterator::NodesTSIterator(
226+
const std::vector<Node *> &source) {
227+
PADDLE_ENFORCE(!source.empty(),
228+
"Start points of topological sorting should not be empty!");
229+
std::unordered_set<Node *> visited;
230+
std::unordered_set<Node *> to_visit{source.begin(), source.end()};
231+
232+
std::vector<Node *> inlink_visited;
233+
while (!to_visit.empty()) {
234+
std::vector<Node *> queue(to_visit.begin(), to_visit.end());
235+
for (auto *p : queue) {
236+
inlink_visited.clear();
237+
238+
std::copy_if(p->inlinks.begin(), p->inlinks.end(),
239+
std::back_inserter(inlink_visited),
240+
[&](Node *x) { return visited.count(x); });
241+
242+
if (inlink_visited.size() == p->inlinks.size()) {
243+
sorted_.push_back(p);
244+
for (auto *_ : p->outlinks) {
245+
if (!visited.count(_)) {
246+
to_visit.insert(_);
247+
}
248+
}
249+
250+
to_visit.erase(p);
251+
visited.insert(p);
252+
}
253+
}
254+
}
255+
}
256+
257+
GraphTraits<DataFlowGraph>::NodesTSIterator::NodesTSIterator(
258+
const paddle::inference::analysis::GraphTraits<
259+
DataFlowGraph>::NodesTSIterator &other)
260+
: sorted_(other.sorted_), cursor_(other.cursor_) {}
261+
262+
Node &GraphTraits<DataFlowGraph>::NodesTSIterator::operator*() {
263+
PADDLE_ENFORCE_LT(cursor_, sorted_.size());
264+
return *sorted_[cursor_];
265+
}
266+
267+
paddle::inference::analysis::GraphTraits<DataFlowGraph>::NodesTSIterator
268+
&GraphTraits<DataFlowGraph>::NodesTSIterator::operator++() {
269+
if (++cursor_ >= sorted_.size()) {
270+
sorted_.clear();
271+
cursor_ = 0;
272+
}
273+
return *this;
274+
}
275+
paddle::inference::analysis::GraphTraits<DataFlowGraph>::NodesTSIterator &
276+
GraphTraits<DataFlowGraph>::NodesTSIterator::operator=(
277+
const paddle::inference::analysis::GraphTraits<
278+
DataFlowGraph>::NodesTSIterator &other) {
279+
cursor_ = other.cursor_;
280+
sorted_ = other.sorted_;
281+
return *this;
282+
}
283+
284+
bool GraphTraits<DataFlowGraph>::NodesTSIterator::operator==(
285+
const paddle::inference::analysis::GraphTraits<
286+
DataFlowGraph>::NodesTSIterator &other) {
287+
return sorted_ == other.sorted_ && cursor_ == other.cursor_;
288+
}
289+
290+
Node *GraphTraits<DataFlowGraph>::NodesTSIterator::operator->() {
291+
PADDLE_ENFORCE_LT(cursor_, sorted_.size());
292+
return sorted_[cursor_];
293+
}
294+
211295
} // namespace analysis
212296
} // namespace inference
213297
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ struct DataFlowGraph {
4848
// Output a DOT graph file for debug.
4949
std::string DotString() const;
5050

51+
std::string HumanReadableInfo(bool show_values = true,
52+
bool show_functions = true) const;
53+
5154
private:
5255
// Remove duplicate edges and so on.
5356
void Clean();
@@ -107,6 +110,32 @@ struct GraphTraits<DataFlowGraph> {
107110
std::unordered_set<Node *> visited_;
108111
};
109112

113+
// Topological sorting iterator on nodes.
114+
struct NodesTSIterator
115+
: public std::iterator<std::forward_iterator_tag, Node *> {
116+
NodesTSIterator() = default;
117+
explicit NodesTSIterator(const std::vector<Node *> &source);
118+
NodesTSIterator(NodesTSIterator &&other)
119+
: sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
120+
other.cursor_ = 0;
121+
}
122+
NodesTSIterator(const NodesTSIterator &other);
123+
124+
Node &operator*();
125+
NodesTSIterator &operator++();
126+
// TODO(Superjomn) current implementation just compare the first
127+
// element, need to compare the graph and all the elements in the queue and
128+
// set.
129+
NodesTSIterator &operator=(const NodesTSIterator &other);
130+
bool operator==(const NodesTSIterator &other);
131+
bool operator!=(const NodesTSIterator &other) { return !(*this == other); }
132+
Node *operator->();
133+
134+
private:
135+
std::vector<Node *> sorted_;
136+
int cursor_{0};
137+
};
138+
110139
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {}
111140

112141
// default use BFS to visit the nodes.
@@ -119,17 +148,24 @@ struct GraphTraits<DataFlowGraph> {
119148
iterator_range<NodesDFSIterator> nodes_in_DFS() {
120149
return iterator_range<NodesDFSIterator>(nodes_dfs_begin(), nodes_dfs_end());
121150
}
151+
iterator_range<NodesTSIterator> nodes_in_TS() {
152+
return iterator_range<NodesTSIterator>(nodes_ts_begin(), nodes_ts_end());
153+
}
122154

123155
private:
124156
NodesBFSIterator nodes_bfs_begin() {
125157
return NodesBFSIterator(graph_->inputs);
126158
}
127159
NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); }
160+
128161
NodesDFSIterator nodes_dfs_begin() {
129162
return NodesDFSIterator(graph_->inputs);
130163
}
131164
NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); }
132165

166+
NodesTSIterator nodes_ts_begin() { return NodesTSIterator(graph_->inputs); }
167+
NodesTSIterator nodes_ts_end() { return NodesTSIterator(); }
168+
133169
private:
134170
DataFlowGraph *graph_;
135171
};

paddle/fluid/inference/analysis/data_flow_graph_tester.cc

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ TEST(DataFlowGraph, BFS) {
2424
auto dfg = ProgramDescToDFG(desc);
2525
dfg.Build();
2626

27-
for (auto* in : dfg.inputs) {
27+
for (auto *in : dfg.inputs) {
2828
LOG(INFO) << "inputs: " << in->name() << " "
2929
<< static_cast<int>(in->type());
3030
}
31-
for (auto* out : dfg.outputs) {
31+
for (auto *out : dfg.outputs) {
3232
LOG(INFO) << "outputs: " << out->name() << " "
3333
<< static_cast<int>(out->type());
3434
}
@@ -57,6 +57,71 @@ TEST(DataFlowGraph, DFS) {
5757
ASSERT_EQ(count, dfg.nodes.size());
5858
}
5959

60+
// Topological sorting.
61+
/*
62+
* Graph topology
63+
* inputs: 0, 1, 2
64+
* 0 -> 4
65+
* 0 -> 5
66+
* 1 -> 6
67+
* 2 -> 7
68+
* 4 -> 5
69+
* 4 -> 7
70+
* 4 -> 3
71+
* 7 -> 3
72+
*/
73+
TEST(DataFlowGraph, TS) {
74+
DataFlowGraph graph;
75+
76+
for (int i = 0; i < 8; i++) {
77+
auto *node = graph.nodes.Create(Node::Type::kValue);
78+
node->SetName("node-" + std::to_string(i));
79+
}
80+
81+
auto add_link = [&](int i, int j) {
82+
Node *source = graph.nodes.GetMutable(i);
83+
Node *target = graph.nodes.GetMutable(j);
84+
target->inlinks.push_back(source);
85+
source->outlinks.push_back(target);
86+
};
87+
88+
graph.inputs.push_back(graph.nodes.GetMutable(0));
89+
graph.inputs.push_back(graph.nodes.GetMutable(1));
90+
graph.inputs.push_back(graph.nodes.GetMutable(2));
91+
92+
add_link(0, 4);
93+
add_link(0, 5);
94+
add_link(1, 6);
95+
add_link(2, 7);
96+
add_link(4, 5);
97+
add_link(4, 7);
98+
add_link(4, 3);
99+
add_link(7, 3);
100+
101+
auto its = GraphTraits<DataFlowGraph>(&graph).nodes_in_TS();
102+
std::vector<int> sorted_ids;
103+
for (auto it = its.begin(); it != its.end(); ++it) {
104+
LOG(INFO) << it->name();
105+
sorted_ids.push_back(it->id());
106+
}
107+
108+
// Assert a occurs prior to b in the sorted_ids.
109+
auto assert_positive_sequence_pair = [&](int a, int b) {
110+
auto a_offset = std::find(sorted_ids.begin(), sorted_ids.end(), a);
111+
auto b_offset = std::find(sorted_ids.begin(), sorted_ids.end(), b);
112+
ASSERT_LT(a_offset, b_offset);
113+
};
114+
115+
assert_positive_sequence_pair(2, 7);
116+
assert_positive_sequence_pair(7, 3);
117+
assert_positive_sequence_pair(4, 3);
118+
assert_positive_sequence_pair(0, 4);
119+
assert_positive_sequence_pair(0, 5);
120+
assert_positive_sequence_pair(1, 6);
121+
assert_positive_sequence_pair(4, 5);
122+
assert_positive_sequence_pair(4, 7);
123+
}
124+
60125
} // namespace analysis
61126
} // namespace inference
62127
} // namespace paddle

0 commit comments

Comments
 (0)