Skip to content

Commit 1153144

Browse files
authored
Inference analysis/init data flow graph analysis (#10776)
Add the demo of subgraph splitter
1 parent a9f9fba commit 1153144

18 files changed

+1321
-76
lines changed
Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,17 @@
1-
cc_library(analysis SRCS dot.cc node.cc node.h)
1+
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
2+
cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc
3+
DEPS paddle_fluid)
24
cc_test(test_node SRCS node_tester.cc DEPS analysis)
5+
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
6+
7+
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
8+
9+
cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid
10+
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
11+
set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec)
12+
13+
cc_test(test_subgraph_splitter
14+
SRCS subgraph_splitter_tester.cc
15+
DEPS analysis paddle_fluid tensor
16+
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
17+
set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec)
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
16+
#include "paddle/fluid/inference/analysis/dot.h"
17+
18+
namespace paddle {
19+
namespace inference {
20+
namespace analysis {
21+
22+
// It is a better idea that the inputs and outputs of this graph is set manully
23+
// before, but there must be a Pass that helps to prune the unnecessary ops that
24+
// do not contribute to the given targets, so in this pass, analysis and get the
25+
// inputs and outputs is OK.
26+
void DataFlowGraph::Build() {
27+
inputs.clear();
28+
outputs.clear();
29+
std::unordered_set<Node *> ins;
30+
std::unordered_set<Node *> outs;
31+
for (auto &node : nodes.nodes()) {
32+
for (auto *in : node->inlinks) {
33+
ins.insert(in);
34+
}
35+
for (auto *out : node->outlinks) {
36+
outs.insert(out);
37+
}
38+
}
39+
40+
// The nodes that in ins but not in outs is the graph's inputs
41+
// similarly, the nodes that in outs but not in ins is the graphs' outputs
42+
for (auto *in : ins) {
43+
if (!outs.count(in)) {
44+
inputs.push_back(in);
45+
}
46+
}
47+
for (auto *out : outs) {
48+
if (!outs.count(out)) {
49+
outputs.push_back(out);
50+
}
51+
}
52+
}
53+
54+
std::string DataFlowGraph::DotString() const {
55+
Dot dot;
56+
57+
// Add nodes
58+
for (size_t i = 0; i < nodes.size(); i++) {
59+
const Node &node = nodes.Get(i);
60+
switch (node.type()) {
61+
case Node::Type::kValue:
62+
dot.AddNode(node.repr(), node.dot_attrs());
63+
break;
64+
case Node::Type::kFunction:
65+
dot.AddNode(node.repr(), node.dot_attrs());
66+
break;
67+
case Node::Type::kFunctionBlock:
68+
dot.AddNode(node.repr(), node.dot_attrs());
69+
break;
70+
default:
71+
PADDLE_THROW("unsupported Node type %d", static_cast<int>(node.type()));
72+
}
73+
}
74+
75+
// Add edges
76+
for (size_t i = 0; i < nodes.size(); i++) {
77+
const Node &node = nodes.Get(i);
78+
for (auto &in : node.inlinks) {
79+
dot.AddEdge(in->repr(), node.repr(), {});
80+
}
81+
}
82+
return dot.Build();
83+
}
84+
85+
//
86+
// NodesBFSIterator
87+
//
88+
89+
// Creates a BFS iterator whose queue initially holds the nodes of `source`,
// in the same order.
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
    const std::vector<Node *> &source) {
  queue_.assign(source.begin(), source.end());
}
92+
93+
// GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
94+
// GraphTraits<DataFlowGraph>::NodesBFSIterator &&other) noexcept
95+
// : queue_(std::move(other.queue_)),
96+
// visited_(std::move(other.visited_)) {}
97+
98+
// Copy constructor: duplicates both the pending queue and the visited set of
// `other`.
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
    const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) {
  queue_ = other.queue_;
  visited_ = other.visited_;
}
101+
102+
// Returns a reference to the node at the front of the BFS queue.
// Enforces that the queue is not empty.
Node &GraphTraits<DataFlowGraph>::NodesBFSIterator::operator*() {
  PADDLE_ENFORCE(!queue_.empty());
  Node *head = queue_.front();
  return *head;
}
106+
107+
// Returns a pointer to the node at the front of the BFS queue.
// Enforces that the queue is not empty.
Node *GraphTraits<DataFlowGraph>::NodesBFSIterator::operator->() {
  PADDLE_ENFORCE(!queue_.empty());
  Node *head = queue_.front();
  return head;
}
111+
112+
// Copy assignment: takes over the pending queue and the visited set of
// `other`, and returns this iterator.
GraphTraits<DataFlowGraph>::NodesBFSIterator &
GraphTraits<DataFlowGraph>::NodesBFSIterator::operator=(
    const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) {
  // Self-assignment would copy each container onto itself; skipping it leaves
  // the iterator in the same state.
  if (this != &other) {
    queue_ = other.queue_;
    visited_ = other.visited_;
  }
  return *this;
}
119+
120+
// Advances the traversal by one node: the front node is removed from the
// queue, recorded as visited, and each of its outlinks that has not been seen
// yet is appended to the queue. Enforces that the queue is not empty.
GraphTraits<DataFlowGraph>::NodesBFSIterator
    &GraphTraits<DataFlowGraph>::NodesBFSIterator::operator++() {
  PADDLE_ENFORCE(!queue_.empty());

  Node *head = queue_.front();
  queue_.pop_front();
  visited_.insert(head);

  for (auto *next : head->outlinks) {
    // insert() reports whether `next` was newly added; only nodes not seen
    // before are enqueued, so a node is never queued twice.
    const bool first_time = visited_.insert(next).second;
    if (first_time) {
      queue_.push_back(next);
    }
  }
  return *this;
}
134+
135+
// Compares two BFS iterators.
// Two exhausted iterators (empty queues) are equal; an exhausted and a
// non-exhausted one are not. Otherwise the iterators are considered equal
// when they point at the same front node and have visited the same number of
// nodes.
bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==(
    const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) {
  const bool lhs_done = queue_.empty();
  const bool rhs_done = other.queue_.empty();
  if (lhs_done || rhs_done) {
    return lhs_done && rhs_done;
  }
  // A full comparison would need to check the equality of the whole queue and
  // the whole visited set. This is just a light but weak implementation.
  const bool same_front = queue_.front() == other.queue_.front();
  return same_front && visited_.size() == other.visited_.size();
}
146+
147+
//
148+
// NodesDFSIterator
149+
//
150+
// Creates a DFS iterator and pushes the nodes of `source` onto the stack in
// order, so the last element of `source` ends up on top.
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
    const std::vector<Node *> &source) {
  for (auto it = source.begin(); it != source.end(); ++it) {
    stack_.push(*it);
  }
}
154+
155+
// GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
156+
// GraphTraits<DataFlowGraph>::NodesDFSIterator &&other) noexcept
157+
// : stack_(std::move(other.stack_)),
158+
// visited_(std::move(other.visited_)) {}
159+
160+
// Copy constructor: duplicates both the pending stack and the visited set of
// `other`.
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
    const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) {
  stack_ = other.stack_;
  visited_ = other.visited_;
}
163+
164+
// Returns a reference to the node on top of the DFS stack.
// Enforces that the stack is not empty.
Node &GraphTraits<DataFlowGraph>::NodesDFSIterator::operator*() {
  PADDLE_ENFORCE(!stack_.empty());
  Node *top = stack_.top();
  return *top;
}
168+
169+
// Advances the traversal by one node: the top node is popped, recorded as
// visited, and each of its outlinks that has not been seen yet is pushed.
// Incrementing an exhausted iterator (empty stack) is a no-op.
GraphTraits<DataFlowGraph>::NodesDFSIterator
    &GraphTraits<DataFlowGraph>::NodesDFSIterator::operator++() {
  if (!stack_.empty()) {
    Node *top = stack_.top();
    stack_.pop();
    visited_.insert(top);

    for (auto *child : top->outlinks) {
      // insert() reports whether `child` was newly added; only nodes not seen
      // before are pushed, so a node is never stacked twice.
      const bool first_time = visited_.insert(child).second;
      if (first_time) {
        stack_.push(child);
      }
    }
  }
  return *this;
}
183+
// Compares two DFS iterators.
// Two exhausted iterators (empty stacks) are equal; an exhausted and a
// non-exhausted one are not. Otherwise only the top nodes are compared; the
// rest of the stack and the visited set are not inspected.
bool GraphTraits<DataFlowGraph>::NodesDFSIterator::operator==(
    const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) {
  const bool lhs_done = stack_.empty();
  const bool rhs_done = other.stack_.empty();
  if (lhs_done || rhs_done) {
    return lhs_done && rhs_done;
  }
  return stack_.top() == other.stack_.top();
}
191+
192+
// Copy assignment: takes over the pending stack and the visited set of
// `other`, and returns this iterator.
GraphTraits<DataFlowGraph>::NodesDFSIterator &
GraphTraits<DataFlowGraph>::NodesDFSIterator::operator=(
    const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) {
  // Self-assignment would copy each container onto itself; skipping it leaves
  // the iterator in the same state.
  if (this != &other) {
    stack_ = other.stack_;
    visited_ = other.visited_;
  }
  return *this;
}
199+
// Returns a pointer to the node on top of the DFS stack.
// Enforces that the stack is not empty: calling std::stack::top() on an empty
// stack is undefined behaviour, so an exhausted iterator must fail loudly
// here, the same way operator*() does.
Node *GraphTraits<DataFlowGraph>::NodesDFSIterator::operator->() {
  PADDLE_ENFORCE(!stack_.empty());
  return stack_.top();
}
202+
203+
} // namespace analysis
204+
} // namespace inference
205+
} // namespace paddle
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
/*
16+
* Data flow graph is a pass that builds the basic graph. It contains a graph
17+
* and the iterators that enable the iteration over the graph.
18+
*/
19+
20+
#pragma once
21+
22+
#include <deque>
23+
#include <stack>
24+
#include <unordered_set>
25+
26+
#include "paddle/fluid/inference/analysis/graph_traits.h"
27+
#include "paddle/fluid/inference/analysis/node.h"
28+
#include "paddle/fluid/platform/enforce.h"
29+
30+
namespace paddle {
31+
namespace inference {
32+
namespace analysis {
33+
34+
/*
35+
* DataFlowGraph - A container of Value and Function Nodes.
36+
*/
37+
struct DataFlowGraph {
38+
NodeMap nodes;
39+
std::vector<Node *> inputs;
40+
std::vector<Node *> outputs;
41+
42+
// Extract inputs and outputs of the graph.
43+
void Build();
44+
45+
// Output a DOT graph file for debug.
46+
std::string DotString() const;
47+
};
48+
49+
/*
50+
* An graph trait help to traverse the graph using BFS.
51+
* The BFS start from a graph's inputs, the graph should be fully-connected, so
52+
* that the iterator can reach the end.
53+
*/
54+
template <>
55+
struct GraphTraits<DataFlowGraph> {
56+
// BFS iterator on nodes.
57+
struct NodesBFSIterator
58+
: public std::iterator<std::forward_iterator_tag, Node *> {
59+
NodesBFSIterator() = default;
60+
explicit NodesBFSIterator(const std::vector<Node *> &source);
61+
// NodesBFSIterator(NodesBFSIterator &&other) noexcept;
62+
// NOTE Heavy to use.
63+
NodesBFSIterator(const NodesBFSIterator &other);
64+
65+
Node &operator*();
66+
NodesBFSIterator &operator++();
67+
Node *operator->();
68+
// TODO(Superjomn) current implementation just compare the first
69+
// element, need to compare the graph and all the elements in the queue and
70+
// set.
71+
NodesBFSIterator &operator=(const NodesBFSIterator &other);
72+
bool operator==(const NodesBFSIterator &other);
73+
bool operator!=(const NodesBFSIterator &other) { return !(*this == other); }
74+
75+
private:
76+
std::deque<Node *> queue_;
77+
std::unordered_set<Node *> visited_;
78+
};
79+
80+
// DFS iterator on nodes.
81+
struct NodesDFSIterator
82+
: public std::iterator<std::forward_iterator_tag, Node *> {
83+
NodesDFSIterator() = default;
84+
explicit NodesDFSIterator(const std::vector<Node *> &source);
85+
// NodesDFSIterator(NodesDFSIterator &&other) noexcept;
86+
NodesDFSIterator(const NodesDFSIterator &other);
87+
88+
Node &operator*();
89+
NodesDFSIterator &operator++();
90+
// TODO(Superjomn) current implementation just compare the first
91+
// element, need to compare the graph and all the elements in the queue and
92+
// set.
93+
NodesDFSIterator &operator=(const NodesDFSIterator &other);
94+
bool operator==(const NodesDFSIterator &other);
95+
bool operator!=(const NodesDFSIterator &other) { return !(*this == other); }
96+
Node *operator->();
97+
98+
private:
99+
std::stack<Node *> stack_;
100+
std::unordered_set<Node *> visited_;
101+
};
102+
103+
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {}
104+
105+
// default use BFS to visit the nodes.
106+
iterator_range<NodesBFSIterator> nodes() {
107+
return iterator_range<NodesBFSIterator>(nodes_bfs_begin(), nodes_bfs_end());
108+
}
109+
iterator_range<NodesBFSIterator> nodes_in_BFS() {
110+
return iterator_range<NodesBFSIterator>(nodes_bfs_begin(), nodes_bfs_end());
111+
}
112+
iterator_range<NodesDFSIterator> nodes_in_DFS() {
113+
return iterator_range<NodesDFSIterator>(nodes_dfs_begin(), nodes_dfs_end());
114+
}
115+
116+
private:
117+
NodesBFSIterator nodes_bfs_begin() {
118+
return NodesBFSIterator(graph_->inputs);
119+
}
120+
NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); }
121+
NodesDFSIterator nodes_dfs_begin() {
122+
return NodesDFSIterator(graph_->inputs);
123+
}
124+
NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); }
125+
126+
private:
127+
DataFlowGraph *graph_;
128+
};
129+
130+
// Extract the inputs and outputs of a graph. The inputs and outputs of a
131+
// sub-graph is the inputs nodes and output nodes that doesn't inside the
132+
// sub-graph.
133+
std::pair<
134+
std::vector<Node *>,
135+
std::vector<
136+
Node *>> static ExtractInputAndOutputOfSubGraph(std::vector<Node *>
137+
&graph) {
138+
std::unordered_set<Node *> nodes(graph.begin(), graph.end());
139+
std::unordered_set<Node *> inputs;
140+
std::unordered_set<Node *> outputs;
141+
for (auto &node : graph) {
142+
for (auto *in : node->inlinks) {
143+
if (!nodes.count(in) && in->type() == Node::Type::kValue) {
144+
inputs.insert(in);
145+
}
146+
}
147+
for (auto *out : node->outlinks) {
148+
if (!nodes.count(out) && out->type() == Node::Type::kValue) {
149+
outputs.insert(out);
150+
}
151+
}
152+
}
153+
return std::make_pair(std::vector<Node *>(inputs.begin(), inputs.end()),
154+
std::vector<Node *>(outputs.begin(), outputs.end()));
155+
}
156+
157+
} // namespace analysis
158+
} // namespace inference
159+
} // namespace paddle

0 commit comments

Comments
 (0)