Skip to content

Commit 03ff4f6

Browse files
committed
fix subgraph bug!
1 parent 5ec2fb0 commit 03ff4f6

File tree

6 files changed

+215
-60
lines changed

6 files changed

+215
-60
lines changed

paddle/fluid/inference/analysis/data_flow_graph.cc

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
440440
}
441441
return false;
442442
};
443+
443444
for (auto &node : graph) {
444445
for (auto *in : node->inlinks) {
445446
// The Value that is written by nodes inside a sub-graph shouldn't be the
@@ -459,6 +460,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
459460
std::vector<Node *>(outputs.begin(), outputs.end()));
460461
}
461462

463+
// Filter the Intermediate results of the subgraph node.
462464
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
463465
std::vector<Node *> op_nodes;
464466
for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) {
@@ -484,46 +486,11 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
484486
out->SetDeleted();
485487
}
486488
}
487-
PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
489+
// The filtered_subgraph_outlinks may be empty.
488490
op_nodes[i]->outlinks = filtered_subgraph_outlinks;
489491
}
490492
}
491493

492-
void FlexibleDFS(const std::vector<Node *> &source, bool reverse,
493-
const std::function<bool(const Node *)> &enter,
494-
const std::function<bool(const Node *)> &leave) {
495-
typedef struct {
496-
const Node *node;
497-
bool leave;
498-
} FNode;
499-
std::vector<FNode> stack;
500-
for (auto &node : source) {
501-
stack.push_back(FNode{node, false});
502-
}
503-
std::unordered_set<const Node *> visited;
504-
while (!stack.empty()) {
505-
auto fnode = stack.back();
506-
stack.pop_back();
507-
508-
if (fnode.leave) {
509-
if (leave && !leave(fnode.node)) return;
510-
}
511-
if (visited.count(fnode.node)) continue;
512-
visited.insert(fnode.node);
513-
514-
if (enter && !enter(fnode.node)) return;
515-
516-
if (leave) stack.push_back(FNode{fnode.node, true});
517-
const std::vector<Node *> iter_nodes =
518-
reverse == true ? fnode.node->inlinks : fnode.node->outlinks;
519-
for (const Node *node : iter_nodes) {
520-
if (!visited.count(node)) {
521-
stack.push_back(FNode{node, false});
522-
}
523-
}
524-
}
525-
}
526-
527494
} // namespace analysis
528495
} // namespace inference
529496
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,6 @@ std::pair<std::vector<Node *>, std::vector<Node *>>
204204
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph); // NOLINT
205205

206206
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph);
207-
void FlexibleDFS(const std::vector<Node *> &source, bool reverse,
208-
const std::function<bool(const Node *)> &enter,
209-
const std::function<bool(const Node *)> &leave);
210207
} // namespace analysis
211208
} // namespace inference
212209
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
106106

107107
// collect inputs
108108
std::unordered_set<std::string> input_names;
109+
std::unordered_set<std::string> input_names_with_id;
109110
for (auto *x : func->inlinks) {
110111
input_names.insert(x->name());
112+
input_names_with_id.insert(x->name() + std::to_string(x->id()));
111113
}
112114
desc.SetInput(
113115
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
114116

115117
std::unordered_set<std::string> output_names;
118+
std::unordered_set<std::string> output_names_with_id;
116119
for (auto *x : func->outlinks) {
117120
output_names.insert(x->name());
121+
output_names_with_id.insert(x->name() + std::to_string(x->id()));
118122
}
119123

120-
std::vector<std::string> output_temp(output_names.begin(),
121-
output_names.end());
122-
desc.SetOutput("Ys", output_temp);
124+
desc.SetOutput(
125+
"Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
123126
desc.SetType("tensorrt_engine");
124127

125128
std::unordered_map<std::string, std::string> output_name_map;
@@ -153,11 +156,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
153156
std::vector<std::string> replaced_names;
154157
for (int k = 0; k < in_var->arguments_size(); k++) {
155158
std::string arg_value = in_var->arguments(k);
156-
if (input_names.count(arg_value)) {
159+
std::string arg_value_with_id =
160+
arg_value + std::to_string(var2id[arg_value]);
161+
if (input_names_with_id.count(arg_value_with_id)) {
157162
replaced_names.push_back(arg_value);
158163
} else {
159-
replaced_names.push_back(arg_value +
160-
std::to_string(var2id[arg_value]));
164+
replaced_names.push_back(arg_value_with_id);
161165
}
162166
}
163167
in_var->clear_arguments();
@@ -176,11 +180,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
176180
std::vector<std::string> replaced_names;
177181
for (int k = 0; k < out_var->arguments_size(); k++) {
178182
std::string arg_value = out_var->arguments(k);
179-
if (output_names.count(arg_value)) {
180-
output_name_map[arg_value] =
181-
arg_value + std::to_string(var2id[arg_value]);
183+
std::string arg_value_with_id =
184+
arg_value + std::to_string(var2id[arg_value]);
185+
if (output_names_with_id.count(arg_value_with_id)) {
186+
output_name_map[arg_value] = arg_value_with_id;
182187
}
183-
replaced_names.push_back(arg_value + std::to_string(var2id[arg_value]));
188+
replaced_names.push_back(arg_value_with_id);
184189
}
185190
out_var->clear_arguments();
186191
for (size_t k = 0; k < replaced_names.size(); k++) {

paddle/fluid/inference/analysis/subgraph_splitter.cc

Lines changed: 181 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,25 +74,200 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
7474
node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor;
7575
}
7676

77+
// This is a simple representation of a graph.
78+
// The BriefNode holds a pointer to the Node.
79+
// This is to avoid changing the original graph
80+
// in the process of trt graph analysis.
81+
struct BriefNode {
82+
explicit BriefNode(Node *n) { node = n; }
83+
Node *node;
84+
std::vector<BriefNode *> inlinks;
85+
std::vector<BriefNode *> outlinks;
86+
};
87+
88+
void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map,
89+
int src_id, int dst_id) {
90+
// merge the two adjacent nodes into one node.
91+
BriefNode *src_node = node_map.at(src_id);
92+
BriefNode *dst_node = node_map.at(dst_id);
93+
94+
std::unordered_set<BriefNode *> inputs(src_node->inlinks.begin(),
95+
src_node->inlinks.end());
96+
std::unordered_set<BriefNode *> outputs;
97+
98+
for (auto *n : src_node->outlinks) {
99+
if (n != dst_node) outputs.insert(n);
100+
}
101+
102+
// Add the inlinks and outlinks of dst node to src node.
103+
std::vector<BriefNode *> dst_in_nodes = dst_node->inlinks;
104+
for (BriefNode *node : dst_in_nodes) {
105+
if (node != src_node) {
106+
inputs.insert(node);
107+
}
108+
}
109+
110+
std::vector<BriefNode *> dst_out_nodes = dst_node->outlinks;
111+
for (BriefNode *node : dst_out_nodes) {
112+
outputs.insert(node);
113+
}
114+
115+
// update the dst and src node's inlinks and outlinks.
116+
src_node->inlinks =
117+
std::move(std::vector<BriefNode *>(inputs.begin(), inputs.end()));
118+
src_node->outlinks =
119+
std::move(std::vector<BriefNode *>(outputs.begin(), outputs.end()));
120+
dst_node->inlinks.clear();
121+
dst_node->outlinks.clear();
122+
123+
auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) {
124+
for (auto *&n : nodes) {
125+
if (n == src_node || n == dst_node) {
126+
n = src_node;
127+
}
128+
}
129+
};
130+
// Change all the dst inputs and outputs corresponding inlink and
131+
// outlink to the src node.
132+
for (auto *node : src_node->inlinks) {
133+
inlink_or_outlink_cleaner(node->outlinks);
134+
}
135+
136+
for (auto *node : src_node->outlinks) {
137+
inlink_or_outlink_cleaner(node->inlinks);
138+
}
139+
}
140+
141+
// FlexibleDFS
142+
// If reverse is true, do reverse dfs.
143+
// If enter func is not nullptr, calls enter(node) before visiting any children
144+
// of node.
145+
// If leave func is not nullptr, calls leave(node) after visiting all parents of
146+
// node.
147+
void FlexibleDFS(const std::vector<BriefNode *> &source, bool reverse,
148+
const std::function<bool(const BriefNode *)> &enter,
149+
const std::function<bool(const BriefNode *)> &leave) {
150+
typedef struct {
151+
const BriefNode *node;
152+
bool leave;
153+
} FNode;
154+
155+
std::vector<FNode> stack;
156+
for (auto &node : source) {
157+
stack.push_back(FNode{node, false});
158+
}
159+
std::unordered_set<const BriefNode *> visited;
160+
while (!stack.empty()) {
161+
auto fnode = stack.back();
162+
stack.pop_back();
163+
164+
if (fnode.leave) {
165+
if (leave && !leave(fnode.node)) return;
166+
}
167+
if (visited.count(fnode.node)) continue;
168+
visited.insert(fnode.node);
169+
170+
if (enter && !enter(fnode.node)) return;
171+
172+
if (leave) stack.push_back(FNode{fnode.node, true});
173+
const std::vector<BriefNode *> iter_nodes =
174+
reverse == true ? fnode.node->inlinks : fnode.node->outlinks;
175+
for (const BriefNode *node : iter_nodes) {
176+
if (!visited.count(node)) {
177+
stack.push_back(FNode{node, false});
178+
}
179+
}
180+
}
181+
}
182+
77183
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
184+
// Run the Extract algorithm to find all subgraphs.
78185
std::vector<Node *> marked_nodes;
186+
// We use brief_node_map to represent the original graph in order to avoid
187+
// changing the original graph.
188+
std::unordered_map<int, BriefNode *> brief_node_map;
189+
79190
for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) {
191+
brief_node_map[node.id()] = new BriefNode(&node);
80192
if (node.attr(kMarkerAttrName).Bool()) {
81193
marked_nodes.push_back(&node);
82194
}
83195
}
196+
84197
// extract sub-graphs in the marked node set, use Union Find algorithm.
85198
node_map_t node_map; // id to ptr
86199
for (auto *n : marked_nodes) {
87200
// n's parent == n.id means it is the ancestor
88201
n->attr(kUnionFindParent).Int32() = n->id();
89202
node_map[n->id()] = n;
90203
}
91-
std::unordered_set<Node *> visited;
92-
for (auto *n : marked_nodes) {
93-
for (auto *out : n->outlinks) {
94-
if (node_map.count(out->id())) {
95-
UnionFindCombine(node_map, n->id(), out->id());
204+
205+
// create brief node map
206+
for (auto &itr : brief_node_map) {
207+
for (Node *node : itr.second->node->inlinks) {
208+
itr.second->inlinks.push_back(brief_node_map[node->id()]);
209+
}
210+
211+
for (Node *node : itr.second->node->outlinks) {
212+
itr.second->outlinks.push_back(brief_node_map[node->id()]);
213+
}
214+
}
215+
216+
for (auto &itr : brief_node_map) {
217+
BriefNode *brief_node = itr.second;
218+
219+
if (!brief_node->node->attr(kMarkerAttrName).Bool()) {
220+
VLOG(4) << brief_node->node->id() << " node not a trt candicate.";
221+
continue;
222+
}
223+
224+
// Our algorithm must guarantee that:
225+
// 1. The graph is always a directed acyclic graph (DAG).
226+
// 2. If there is a path in the subgraph from X to Y (X and Y are both
227+
// nodes
228+
// in the subgraph), then all paths from X to Y are in the subgraph.
229+
//
230+
// In order to achieve the above guarantee.
231+
// For adjacent nodes src -> dst.
232+
// 1. Get all dst input nodes except src.
233+
// 2. Reverse DFS from those input nodes
234+
// 3. If there is a path from input nodes to src,
235+
// then the src and dst nodes can not be fused into one node,
236+
// otherwise it can be done.
237+
238+
while (true) {
239+
std::unordered_set<BriefNode *> contract_nodes;
240+
for (auto *out : brief_node->outlinks) {
241+
// must be a trt candidate
242+
if (!out->node->attr(kMarkerAttrName).Bool()) continue;
243+
// get all dst input nodes except src.
244+
std::vector<BriefNode *> source_nodes;
245+
for (auto *n : out->inlinks) {
246+
if (n != brief_node) {
247+
source_nodes.push_back(n);
248+
}
249+
}
250+
251+
// Reverse DFS from the source_nodes.
252+
bool have_excess_path = false;
253+
FlexibleDFS(source_nodes, true, nullptr,
254+
[&have_excess_path, brief_node](const BriefNode *n) {
255+
if (n == brief_node) {
256+
have_excess_path = true;
257+
return false;
258+
}
259+
return true;
260+
});
261+
if (have_excess_path) continue;
262+
contract_nodes.insert(out);
263+
}
264+
if (contract_nodes.empty()) break;
265+
266+
for (auto dst_node : contract_nodes) {
267+
UnionFindCombine(node_map, brief_node->node->id(),
268+
dst_node->node->id());
269+
UnionContractedNodes(brief_node_map, brief_node->node->id(),
270+
dst_node->node->id());
96271
}
97272
}
98273
}
@@ -128,6 +303,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
128303
auto io = ExtractInputAndOutputOfSubGraph(subgraph);
129304
block_node->inlinks = std::move(io.first);
130305
block_node->outlinks = std::move(io.second);
306+
131307
for (auto *node : subgraph) {
132308
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
133309
// pass.

paddle/fluid/inference/analysis/subgraph_splitter_tester.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ TEST(SubGraphSplitter, Fuse) {
8282

8383
// At least one node should be deleted.
8484
ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock
85-
ASSERT_EQ(6, count1);
85+
ASSERT_EQ(11, count1);
8686
}
8787

8888
} // namespace analysis

paddle/fluid/operators/tensorrt_engine_op.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,21 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
160160
fluid_t->mutable_data<float>(platform::CUDAPlace(
161161
boost::get<platform::CUDAPlace>(context.GetPlace()).device)),
162162
size * sizeof(float));
163-
//} else {
164-
// engine->GetOutputInGPU(
165-
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
166-
// size * sizeof(float));
167-
//}
163+
164+
// TODO(zhaolong): delete it sometime
165+
/* THIS CODE JUST FOR TEST
166+
std::cout << output_maps[output_index] << std::endl;
167+
platform::CPUPlace cpu_place;
168+
framework::LoDTensor temp_tensor;
169+
temp_tensor.Resize(framework::make_ddim(ddim));
170+
auto* temp_data = temp_tensor.mutable_data<float>(cpu_place);
171+
172+
TensorCopySync(*fluid_t, cpu_place ,&temp_tensor);
173+
for(int i = 0; i < size; i++) {
174+
std::cout << temp_data[i] << " " ;
175+
}
176+
std::cout << std::endl;
177+
*/
168178
output_index += 1;
169179
}
170180

0 commit comments

Comments
 (0)