Commit c999528

Merge pull request #13124 from NHZlX/fix_subgraph_bug
Fix tensorrt subgraph bug
2 parents d4a5326 + 8fb33c8

File tree

12 files changed: +227 additions, −22 deletions

paddle/fluid/inference/analysis/data_flow_graph.cc

Lines changed: 5 additions & 1 deletion

@@ -440,6 +440,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) {  // NOLINT
     }
     return false;
   };
+
   for (auto &node : graph) {
     for (auto *in : node->inlinks) {
       // The Value that is written by nodes inside a sub-graph shouldn't be the
@@ -459,6 +460,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) {  // NOLINT
                         std::vector<Node *>(outputs.begin(), outputs.end()));
 }
 
+// Filter the Intermediate results of the subgraph node.
 void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
   std::vector<Node *> op_nodes;
   for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) {
@@ -480,9 +482,11 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
     for (auto *out : op_nodes[i]->outlinks) {
       if (follow_up_input_names.count(out->name())) {
         filtered_subgraph_outlinks.push_back(out);
+      } else {
+        out->SetDeleted();
       }
     }
-    PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
+    // The filtered_subgraph_outlinks may be empty.
     op_nodes[i]->outlinks = filtered_subgraph_outlinks;
   }
 }
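
The change above relaxes the old PADDLE_ENFORCE_GE(..., 1UL) check: outputs that no follow-up op reads are now marked deleted, and a subgraph op may legitimately end up with zero surviving outlinks. A minimal, self-contained sketch of that filtering rule — SimpleNode and FilterOutputs are illustrative stand-ins, not the real analysis::Node API:

#include <string>
#include <unordered_set>
#include <vector>

struct SimpleNode {
  std::string name;
  bool deleted = false;
  std::vector<SimpleNode *> outlinks;
};

// Keep only the outputs some follow-up op actually reads; mark the rest
// deleted. No minimum size is enforced: the kept list may be empty.
void FilterOutputs(SimpleNode *op,
                   const std::unordered_set<std::string> &follow_up_inputs) {
  std::vector<SimpleNode *> kept;
  for (SimpleNode *out : op->outlinks) {
    if (follow_up_inputs.count(out->name)) {
      kept.push_back(out);
    } else {
      out->deleted = true;  // intermediate result, invisible outside the graph
    }
  }
  op->outlinks = std::move(kept);  // may legitimately be empty now
}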

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc

Lines changed: 15 additions & 10 deletions

@@ -106,20 +106,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
 
   // collect inputs
   std::unordered_set<std::string> input_names;
+  std::unordered_set<std::string> input_names_with_id;
   for (auto *x : func->inlinks) {
     input_names.insert(x->name());
+    input_names_with_id.insert(x->name() + std::to_string(x->id()));
   }
   desc.SetInput(
       "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
 
   std::unordered_set<std::string> output_names;
+  std::unordered_set<std::string> output_names_with_id;
   for (auto *x : func->outlinks) {
     output_names.insert(x->name());
+    output_names_with_id.insert(x->name() + std::to_string(x->id()));
   }
 
-  std::vector<std::string> output_temp(output_names.begin(),
-                                       output_names.end());
-  desc.SetOutput("Ys", output_temp);
+  desc.SetOutput(
+      "Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
   desc.SetType("tensorrt_engine");
 
   std::unordered_map<std::string, std::string> output_name_map;
@@ -153,11 +156,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
     std::vector<std::string> replaced_names;
     for (int k = 0; k < in_var->arguments_size(); k++) {
       std::string arg_value = in_var->arguments(k);
-      if (input_names.count(arg_value)) {
+      std::string arg_value_with_id =
+          arg_value + std::to_string(var2id[arg_value]);
+      if (input_names_with_id.count(arg_value_with_id)) {
         replaced_names.push_back(arg_value);
       } else {
-        replaced_names.push_back(arg_value +
-                                 std::to_string(var2id[arg_value]));
+        replaced_names.push_back(arg_value_with_id);
       }
     }
     in_var->clear_arguments();
@@ -176,11 +180,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
     std::vector<std::string> replaced_names;
     for (int k = 0; k < out_var->arguments_size(); k++) {
       std::string arg_value = out_var->arguments(k);
-      if (output_names.count(arg_value)) {
-        output_name_map[arg_value] =
-            arg_value + std::to_string(var2id[arg_value]);
+      std::string arg_value_with_id =
+          arg_value + std::to_string(var2id[arg_value]);
+      if (output_names_with_id.count(arg_value_with_id)) {
+        output_name_map[arg_value] = arg_value_with_id;
       }
-      replaced_names.push_back(arg_value + std::to_string(var2id[arg_value]));
+      replaced_names.push_back(arg_value_with_id);
     }
     out_var->clear_arguments();
     for (size_t k = 0; k < replaced_names.size(); k++) {
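
The bug this pass fixes: input_names and output_names are keyed on bare variable names, but distinct graph nodes can share a name, so a bare-name lookup can misclassify an argument. The commit therefore keys the lookups on name + node id. A toy sketch of the difference, with hypothetical values rather than the real pass state:

#include <iostream>
#include <string>
#include <unordered_set>

int main() {
  // The engine input is the node named "x" with id 7.
  std::unordered_set<std::string> input_names{"x"};
  std::unordered_set<std::string> input_names_with_id{"x" + std::to_string(7)};

  // A different node that happens to also be named "x".
  int var2id_x = 9;
  std::string arg = "x";
  std::string arg_with_id = arg + std::to_string(var2id_x);

  std::cout << input_names.count(arg) << "\n";                  // 1: false match
  std::cout << input_names_with_id.count(arg_with_id) << "\n";  // 0: correct
  return 0;
}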

paddle/fluid/inference/analysis/subgraph_splitter.cc

Lines changed: 189 additions & 5 deletions

@@ -74,25 +74,208 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
   node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor;
 }
 
+// This is a simple representation of a graph.
+// The BriefNode holds a pointer to the original Node.
+// This is to avoid changing the original graph
+// in the process of trt graph analysis.
+struct BriefNode {
+  explicit BriefNode(Node *n) { node = n; }
+  Node *node;
+  std::vector<BriefNode *> inlinks;
+  std::vector<BriefNode *> outlinks;
+};
+
+// Union two adjacent BriefNodes.
+// Suppose we have two adjacent nodes src and dst.
+// We will perform the following operations:
+// 1. add all inputs (except src) of dst to src inlinks.
+// 2. add all outputs of dst to src outlinks.
+// 3. change the inlinks and outlinks of all of dst's inputs and outputs
+//    to point to the src node.
+// 4. delete all of dst's inlinks and outlinks.
+void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map,
+                          int src_id, int dst_id) {
+  // merge the two adjacent nodes into one node.
+  BriefNode *src_node = node_map.at(src_id);
+  BriefNode *dst_node = node_map.at(dst_id);
+
+  std::unordered_set<BriefNode *> inputs(src_node->inlinks.begin(),
+                                         src_node->inlinks.end());
+  std::unordered_set<BriefNode *> outputs;
+
+  for (auto *n : src_node->outlinks) {
+    if (n != dst_node) outputs.insert(n);
+  }
+
+  // Add the inlinks and outlinks of dst node to src node.
+  std::vector<BriefNode *> dst_in_nodes = dst_node->inlinks;
+  for (BriefNode *node : dst_in_nodes) {
+    if (node != src_node) {
+      inputs.insert(node);
+    }
+  }
+
+  std::vector<BriefNode *> dst_out_nodes = dst_node->outlinks;
+  for (BriefNode *node : dst_out_nodes) {
+    outputs.insert(node);
+  }
+
+  // update the dst and src node's inlinks and outlinks.
+  src_node->inlinks =
+      std::move(std::vector<BriefNode *>(inputs.begin(), inputs.end()));
+  src_node->outlinks =
+      std::move(std::vector<BriefNode *>(outputs.begin(), outputs.end()));
+  dst_node->inlinks.clear();
+  dst_node->outlinks.clear();
+
+  auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) {
+    for (auto *&n : nodes) {
+      if (n == src_node || n == dst_node) {
+        n = src_node;
+      }
+    }
+  };
+  // Change the inlinks and outlinks of all of dst's inputs and outputs
+  // to point to the src node.
+  for (auto *node : src_node->inlinks) {
+    inlink_or_outlink_cleaner(node->outlinks);
+  }
+
+  for (auto *node : src_node->outlinks) {
+    inlink_or_outlink_cleaner(node->inlinks);
+  }
+}
+
+// FlexibleDFS
+// If reverse is true, do a reverse DFS.
+// If the enter func is not nullptr, calls enter(node) before visiting any
+// children of node.
+// If the leave func is not nullptr, calls leave(node) after visiting all
+// parents of node.
+void FlexibleDFS(const std::vector<BriefNode *> &source, bool reverse,
+                 const std::function<bool(const BriefNode *)> &enter,
+                 const std::function<bool(const BriefNode *)> &leave) {
+  typedef struct {
+    const BriefNode *node;
+    bool leave;
+  } FNode;
+
+  std::vector<FNode> stack;
+  for (auto &node : source) {
+    stack.push_back(FNode{node, false});
+  }
+  std::unordered_set<const BriefNode *> visited;
+  while (!stack.empty()) {
+    auto fnode = stack.back();
+    stack.pop_back();
+
+    if (fnode.leave) {
+      if (leave && !leave(fnode.node)) return;
+    }
+    if (visited.count(fnode.node)) continue;
+    visited.insert(fnode.node);
+
+    if (enter && !enter(fnode.node)) return;
+
+    if (leave) stack.push_back(FNode{fnode.node, true});
+    const std::vector<BriefNode *> iter_nodes =
+        reverse == true ? fnode.node->inlinks : fnode.node->outlinks;
+    for (const BriefNode *node : iter_nodes) {
+      if (!visited.count(node)) {
+        stack.push_back(FNode{node, false});
+      }
+    }
+  }
+}
+
 std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
+  // Run the Extract algorithm to find all subgraphs.
   std::vector<Node *> marked_nodes;
+  // We use brief_node_map to represent the original graph in order to avoid
+  // changing the original graph.
+  std::unordered_map<int, BriefNode *> brief_node_map;
+
   for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) {
+    brief_node_map[node.id()] = new BriefNode(&node);
     if (node.attr(kMarkerAttrName).Bool()) {
       marked_nodes.push_back(&node);
     }
   }
+
   // extract sub-graphs in the marked node set, use Union Find algorithm.
   node_map_t node_map;  // id to ptr
   for (auto *n : marked_nodes) {
     // n's parent == n.id means it is the ancestor
     n->attr(kUnionFindParent).Int32() = n->id();
     node_map[n->id()] = n;
   }
-  std::unordered_set<Node *> visited;
-  for (auto *n : marked_nodes) {
-    for (auto *out : n->outlinks) {
-      if (node_map.count(out->id())) {
-        UnionFindCombine(node_map, n->id(), out->id());
+
+  // create the brief node map
+  for (auto &itr : brief_node_map) {
+    for (Node *node : itr.second->node->inlinks) {
+      itr.second->inlinks.push_back(brief_node_map[node->id()]);
+    }
+
+    for (Node *node : itr.second->node->outlinks) {
+      itr.second->outlinks.push_back(brief_node_map[node->id()]);
+    }
+  }
+
+  for (auto &itr : brief_node_map) {
+    BriefNode *brief_node = itr.second;
+
+    if (!brief_node->node->attr(kMarkerAttrName).Bool()) {
+      VLOG(4) << brief_node->node->id() << " node is not a trt candidate.";
+      continue;
+    }
+
+    // Our algorithm must guarantee that:
+    // 1. The graph is always a directed acyclic graph (DAG).
+    // 2. If there is a path in the subgraph from X to Y (X and Y are both
+    //    nodes in the subgraph), then all paths from X to Y are in the
+    //    subgraph.
+    //
+    // In order to achieve the above guarantee,
+    // for adjacent nodes src -> dst:
+    // 1. Get all dst input nodes except src.
+    // 2. Reverse DFS from those input nodes.
+    // 3. If there is a path from the input nodes to src,
+    //    then src and dst can not be fused into one node;
+    //    otherwise it can be done.
+
+    while (true) {
+      std::unordered_set<BriefNode *> contract_nodes;
+      for (auto *out : brief_node->outlinks) {
+        // must be a trt candidate
+        if (!out->node->attr(kMarkerAttrName).Bool()) continue;
+        // get all dst input nodes except src.
+        std::vector<BriefNode *> source_nodes;
+        for (auto *n : out->inlinks) {
+          if (n != brief_node) {
+            source_nodes.push_back(n);
+          }
+        }
+
+        // Reverse DFS from the source_nodes.
+        bool have_excess_path = false;
+        FlexibleDFS(source_nodes, true, nullptr,
+                    [&have_excess_path, brief_node](const BriefNode *n) {
+                      if (n == brief_node) {
+                        have_excess_path = true;
+                        return false;
+                      }
+                      return true;
+                    });
+        if (have_excess_path) continue;
+        contract_nodes.insert(out);
+      }
+      if (contract_nodes.empty()) break;
+
+      for (auto dst_node : contract_nodes) {
+        UnionFindCombine(node_map, brief_node->node->id(),
+                         dst_node->node->id());
+        UnionContractedNodes(brief_node_map, brief_node->node->id(),
+                             dst_node->node->id());
       }
     }
   }
@@ -128,6 +311,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
     auto io = ExtractInputAndOutputOfSubGraph(subgraph);
     block_node->inlinks = std::move(io.first);
     block_node->outlinks = std::move(io.second);
+
     for (auto *node : subgraph) {
       // TODO(Superjomn) need a unified mechanism to treat deleted node in each
       // pass.
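
The heart of the new splitter is the "excess path" test: before contracting src -> dst, it reverse-DFSes from every other input of dst, and if that walk reaches src there is a path that leaves the candidate subgraph and re-enters it, so merging would violate the DAG guarantee. A standalone sketch of that reachability check — MiniNode is an illustrative stand-in for BriefNode, and ReachesViaInlinks plays the role of the FlexibleDFS call with reverse = true:

#include <unordered_set>
#include <vector>

struct MiniNode {
  std::vector<MiniNode *> inlinks;
  std::vector<MiniNode *> outlinks;
};

// Reverse DFS from `sources`; returns true iff `target` is reachable by
// walking inlinks. Mirrors the have_excess_path flag in the pass above.
bool ReachesViaInlinks(const std::vector<MiniNode *> &sources,
                       const MiniNode *target) {
  std::vector<const MiniNode *> stack(sources.begin(), sources.end());
  std::unordered_set<const MiniNode *> visited;
  while (!stack.empty()) {
    const MiniNode *n = stack.back();
    stack.pop_back();
    if (!visited.insert(n).second) continue;
    if (n == target) return true;  // excess path found: do not contract
    for (MiniNode *in : n->inlinks) stack.push_back(in);
  }
  return false;
}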

paddle/fluid/inference/analysis/subgraph_splitter_tester.cc

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ TEST(SubGraphSplitter, Fuse) {
 
   // At least one nodes should be deleted.
   ASSERT_EQ(dfg.nodes.size(), count0 + 1);  // added a new FunctionBlock
-  ASSERT_EQ(6, count1);
+  ASSERT_EQ(11, count1);
 }
 
 }  // namespace analysis

paddle/fluid/inference/tensorrt/convert/activation_op.cc

Lines changed: 2 additions & 0 deletions

@@ -35,6 +35,8 @@ class ReluOpConverter : public OpConverter {
         engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
         nvinfer1::ActivationType::kRELU);
     auto output_name = op_desc.Output("Out")[0];
+    layer->setName(("relu (Output: " + output_name + ")").c_str());
+    layer->getOutput(0)->setName(output_name.c_str());
     engine_->SetITensor(output_name, layer->getOutput(0));
     if (test_mode) {  // the test framework can not determine which is the
                       // output, so place the declaration inside.
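
The same two added lines recur in every converter below (batch_norm, concat, conv2d, elementwise, fc): the layer is named after its op type plus the fluid output variable, and the layer's output tensor after the variable itself, so TensorRT build logs and engine inspection map back to the original graph. A hypothetical helper capturing the pattern — FakeLayer/FakeTensor are stand-ins so the sketch compiles without TensorRT headers; the real calls are nvinfer1::ILayer::setName and nvinfer1::ITensor::setName:

#include <string>

struct FakeTensor {
  std::string name;
  void setName(const char *n) { name = n; }
};

struct FakeLayer {
  std::string name;
  FakeTensor output;
  void setName(const char *n) { name = n; }
  FakeTensor *getOutput(int) { return &output; }
};

// Name the layer "<op_type> (Output: <var>)" and its first output "<var>",
// matching the pattern used by the converters in this commit.
void NameLayer(FakeLayer *layer, const std::string &op_type,
               const std::string &output_name) {
  layer->setName((op_type + " (Output: " + output_name + ")").c_str());
  layer->getOutput(0)->setName(output_name.c_str());
}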

paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc

Lines changed: 2 additions & 0 deletions

@@ -116,6 +116,8 @@ class BatchNormOpConverter : public OpConverter {
         scale_weights.get(), power_weights.get());
 
     auto output_name = op_desc.Output("Y").front();
+    layer->setName(("batch_norm (Output: " + output_name + ")").c_str());
+    layer->getOutput(0)->setName(output_name.c_str());
     engine_->weight_map[op_desc.Input("Bias").front()] =
         std::move(combile_bias_tensor);
     engine_->weight_map[op_desc.Input("Scale").front()] =

paddle/fluid/inference/tensorrt/convert/concat_op.cc

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,8 @@ class ConcatOpConverter : public OpConverter {
     axis = axis - 1;  // Remove batch dim
     layer->setAxis(axis);
     auto output_name = op_desc.Output("Out")[0];
+    layer->setName(("concat (Output: " + output_name + ")").c_str());
+    layer->getOutput(0)->setName(output_name.c_str());
     engine_->SetITensor(output_name, layer->getOutput(0));
     if (test_mode) {  // the test framework can not determine which is the
                       // output, so place the declaration inside.

paddle/fluid/inference/tensorrt/convert/conv2d_op.cc

Lines changed: 2 additions & 0 deletions

@@ -78,8 +78,10 @@ class Conv2dOpConverter : public OpConverter {
     layer->setNbGroups(groups);
 
     auto output_name = op_desc.Output("Output").front();
+    layer->setName(("conv2d (Output: " + output_name + ")").c_str());
     engine_->weight_map[op_desc.Input("Filter").front()] =
         std::move(weight_tensor);
+    layer->getOutput(0)->setName(output_name.c_str());
     engine_->SetITensor(output_name, layer->getOutput(0));
     if (test_mode) {
       engine_->DeclareOutput(output_name);

paddle/fluid/inference/tensorrt/convert/elementwise_op.cc

Lines changed: 4 additions & 0 deletions

@@ -89,6 +89,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
         shift_weights.get(), scale_weights.get(), power_weights.get());
     auto output_name = op_desc.Output("Out")[0];
 
+    layer->setName(("elementwise_add (Output: " + output_name + ")").c_str());
+    layer->getOutput(0)->setName(output_name.c_str());
     engine_->weight_map[op_desc.Input("Y").front()] = std::move(weight_tensor);
     engine_->SetITensor(output_name, layer->getOutput(0));
     if (test_mode) {  // the test framework can not determine which is the
@@ -137,6 +139,8 @@ class ElementwiseTensorOpConverter : public OpConverter {
         *const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
 
     auto output_name = op_desc.Output("Out")[0];
+    layer->setName(("elementwise (Output: " + output_name + ")").c_str());
+    layer->getOutput(0)->setName(output_name.c_str());
     engine_->SetITensor(output_name, layer->getOutput(0));
     if (test_mode) {  // the test framework can not determine which is the
                       // output, so place the declaration inside.

paddle/fluid/inference/tensorrt/convert/fc_op.cc

Lines changed: 2 additions & 0 deletions

@@ -107,6 +107,8 @@ class FcOpConverter : public OpConverter {
                                 n_output, tmp_weight.get(), bias.get());
 
     auto output_name = op_desc.Output("Out").front();
+    layer->setName(("fc (Output: " + output_name + ")").c_str());
+    layer->getOutput(0)->setName(output_name.c_str());
     engine_->SetITensor(output_name, layer->getOutput(0));
     engine_->weight_map[op_desc.Input("Y").front()] = std::move(tmp);
     if (test_mode) {
