Skip to content

Commit b42ced8

Browse files
authored
bugfix/tensorrt analysis fix subgraph trigger (#12266)
1 parent c5c17a1 commit b42ced8

27 files changed

+342
-188
lines changed

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@
2222
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
2323

2424
namespace paddle {
25-
namespace inference {
26-
namespace analysis {
2725

2826
DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
2927
"Enable subgraph to TensorRT engine for acceleration");
3028

3129
DEFINE_string(inference_analysis_graphviz_log_root, "./",
3230
"Graphviz debuger for data flow graphs.");
3331

32+
namespace inference {
33+
namespace analysis {
34+
3435
class DfgPassManagerImpl final : public DfgPassManager {
3536
public:
3637
DfgPassManagerImpl() {

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,15 @@ limitations under the License. */
4545
#include "paddle/fluid/inference/analysis/pass_manager.h"
4646

4747
namespace paddle {
48-
namespace inference {
49-
namespace analysis {
5048

5149
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
5250
// flag if not available.
5351
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine);
5452
DECLARE_string(inference_analysis_graphviz_log_root);
5553

54+
namespace inference {
55+
namespace analysis {
56+
5657
class Analyzer : public OrderedRegistry<PassManager> {
5758
public:
5859
// Register all the pass-managers.

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,21 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/inference/analysis/analyzer.h"
16+
#include <google/protobuf/text_format.h>
1617
#include "paddle/fluid/inference/analysis/ut_helper.h"
1718

1819
namespace paddle {
1920
namespace inference {
2021
namespace analysis {
2122

22-
TEST_F(DFG_Tester, main) {
23+
TEST_F(DFG_Tester, analysis_without_tensorrt) {
24+
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false;
25+
Analyzer analyser;
26+
analyser.Run(&argument);
27+
}
28+
29+
TEST_F(DFG_Tester, analysis_with_tensorrt) {
30+
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true;
2331
Analyzer analyser;
2432
analyser.Run(&argument);
2533
}

paddle/fluid/inference/analysis/data_flow_graph.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,17 +222,31 @@ Node *GraphTraits<DataFlowGraph>::NodesDFSIterator::operator->() {
222222
return stack_.top();
223223
}
224224

225+
inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
226+
return node.inlinks.size() == n;
227+
}
228+
225229
GraphTraits<DataFlowGraph>::NodesTSIterator::NodesTSIterator(
226230
const std::vector<Node *> &source) {
227231
PADDLE_ENFORCE(!source.empty(),
228232
"Start points of topological sorting should not be empty!");
233+
// CHECK all the inputs' in-degree is 0
234+
for (auto *node : source) {
235+
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
236+
}
237+
229238
std::unordered_set<Node *> visited;
230239
std::unordered_set<Node *> to_visit{source.begin(), source.end()};
231240

232241
std::vector<Node *> inlink_visited;
233242
while (!to_visit.empty()) {
234243
std::vector<Node *> queue(to_visit.begin(), to_visit.end());
235244
for (auto *p : queue) {
245+
if (p->deleted()) {
246+
visited.insert(p);
247+
to_visit.erase(p);
248+
continue;
249+
}
236250
inlink_visited.clear();
237251

238252
std::copy_if(p->inlinks.begin(), p->inlinks.end(),
@@ -292,6 +306,37 @@ Node *GraphTraits<DataFlowGraph>::NodesTSIterator::operator->() {
292306
return sorted_[cursor_];
293307
}
294308

309+
std::pair<std::vector<Node *>, std::vector<Node *>>
310+
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
311+
std::unordered_set<Node *> nodes(graph.begin(), graph.end());
312+
std::unordered_set<Node *> inputs;
313+
std::unordered_set<Node *> outputs;
314+
// Input a Value, check whether its inlink is in the subgraph.
315+
auto inlink_in_subgraph = [&](Node *n) {
316+
for (auto *in : n->inlinks) {
317+
if (nodes.count(in)) return true;
318+
}
319+
return false;
320+
};
321+
for (auto &node : graph) {
322+
for (auto *in : node->inlinks) {
323+
// The Value that is written by nodes inside a sub-graph shouldn't be the
324+
// input of the sub-graph.
325+
if (!nodes.count(in) && in->type() == Node::Type::kValue &&
326+
!inlink_in_subgraph(in)) {
327+
inputs.insert(in);
328+
}
329+
}
330+
for (auto *out : node->outlinks) {
331+
if (!nodes.count(out) && out->type() == Node::Type::kValue) {
332+
outputs.insert(out);
333+
}
334+
}
335+
}
336+
return std::make_pair(std::vector<Node *>(inputs.begin(), inputs.end()),
337+
std::vector<Node *>(outputs.begin(), outputs.end()));
338+
}
339+
295340
} // namespace analysis
296341
} // namespace inference
297342
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph.h

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ struct GraphTraits<DataFlowGraph> {
133133

134134
private:
135135
std::vector<Node *> sorted_;
136-
int cursor_{0};
136+
size_t cursor_{0};
137137
};
138138

139139
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {}
@@ -173,36 +173,8 @@ struct GraphTraits<DataFlowGraph> {
173173
// Extract the inputs and outputs of a graph. The inputs and outputs of a
174174
// sub-graph is the inputs nodes and output nodes that doesn't inside the
175175
// sub-graph.
176-
static std::pair<std::vector<Node *>, std::vector<Node *>>
177-
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
178-
std::unordered_set<Node *> nodes(graph.begin(), graph.end());
179-
std::unordered_set<Node *> inputs;
180-
std::unordered_set<Node *> outputs;
181-
// Input a Value, check whether its inlink is in the subgraph.
182-
auto inlink_in_subgraph = [&](Node *n) {
183-
for (auto *in : n->inlinks) {
184-
if (nodes.count(in)) return true;
185-
}
186-
return false;
187-
};
188-
for (auto &node : graph) {
189-
for (auto *in : node->inlinks) {
190-
// The Value that is written by nodes inside a sub-graph shouldn't be the
191-
// input of the sub-graph.
192-
if (!nodes.count(in) && in->type() == Node::Type::kValue &&
193-
!inlink_in_subgraph(in)) {
194-
inputs.insert(in);
195-
}
196-
}
197-
for (auto *out : node->outlinks) {
198-
if (!nodes.count(out) && out->type() == Node::Type::kValue) {
199-
outputs.insert(out);
200-
}
201-
}
202-
}
203-
return std::make_pair(std::vector<Node *>(inputs.begin(), inputs.end()),
204-
std::vector<Node *>(outputs.begin(), outputs.end()));
205-
}
176+
std::pair<std::vector<Node *>, std::vector<Node *>>
177+
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph);
206178

207179
} // namespace analysis
208180
} // namespace inference

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,18 @@
2222

2323
namespace paddle {
2424
namespace inference {
25+
26+
DEFINE_int32(tensorrt_max_batchsize, 300, "TensorRT maximum batch size");
27+
DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size");
28+
2529
namespace analysis {
2630

2731
using framework::proto::ProgramDesc;
2832

2933
std::vector<std::string> ExtractParameters(
30-
const std::vector<std::unique_ptr<Node>>& nodes);
34+
const std::vector<std::unique_ptr<Node>> &nodes);
3135

32-
bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
36+
bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
3337
ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
3438
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
3539
PADDLE_ENFORCE(!argument->transformed_program_desc);
@@ -47,76 +51,77 @@ bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
4751

4852
bool DataFlowGraphToFluidPass::Finalize() { return true; }
4953

50-
void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
51-
auto traits = GraphTraits<DataFlowGraph>(graph);
52-
for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) {
53-
if (it->deleted()) continue;
54+
void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
55+
LOG(INFO) << "graph.inputs " << graph->inputs.size();
56+
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
57+
if (node.deleted()) continue;
5458

55-
switch (it->type()) {
59+
switch (node.type()) {
5660
case Node::Type::kFunction: {
57-
LOG(INFO) << "add function " << it->repr();
58-
AddFluidOp(&(*it));
61+
LOG(INFO) << "add function " << node.repr();
62+
AddFluidOp(&node);
5963
} break;
6064
case Node::Type::kFunctionBlock: {
61-
LOG(INFO) << "add engine op " << it->repr() << " , "
62-
<< static_cast<FunctionBlock*>(&(*it))->subgraph.size();
63-
AddEngineOp(&(*it));
65+
LOG(INFO) << "add engine op " << node.repr() << " , "
66+
<< static_cast<FunctionBlock *>(&node)->subgraph.size();
67+
AddEngineOp(&node);
6468
} break;
6569
default:
6670
continue;
6771
}
6872
}
73+
74+
PADDLE_ENFORCE(argument_->transformed_program_desc.get());
6975
}
7076

71-
void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
72-
auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc());
77+
void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
78+
auto *ori_op = static_cast<framework::proto::OpDesc *>(node->pb_desc());
7379
// currently only the main block is analyzed.
74-
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
75-
auto* op = main_block->add_ops();
80+
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
81+
auto *op = main_block->add_ops();
7682
*op = *ori_op; // copy the attributes, by default, these will not be changed
77-
// by analysis phrase.
83+
// by analysis phrase.
7884
// The inputs and outputs of the existing ops are not changed by tensorrt
7985
// subgraph pass.
8086
// NOTE It might be changed by other passes in the long run.
8187
}
8288

83-
void CreateTrtEngineOp(Node* node, const DataFlowGraph& graph,
84-
const framework::proto::BlockDesc& block) {
89+
void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
90+
const framework::proto::BlockDesc &block) {
8591
static int counter{0};
8692
PADDLE_ENFORCE(node->IsFunctionBlock());
8793
framework::OpDesc desc;
88-
auto* func = static_cast<FunctionBlock*>(node);
94+
auto *func = static_cast<FunctionBlock *>(node);
8995

9096
// collect inputs
9197
std::vector<std::string> io;
92-
for (auto* x : func->inlinks) {
98+
for (auto *x : func->inlinks) {
9399
io.push_back(x->name());
94100
}
95101
desc.SetInput("Xs", io);
96102

97103
// collect outputs
98104
io.clear();
99-
for (auto* x : func->outlinks) {
105+
for (auto *x : func->outlinks) {
100106
io.push_back(x->name());
101107
}
102108
desc.SetOutput("Ys", io);
103-
104109
desc.SetType("tensorrt_engine");
110+
111+
PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc");
105112
// Set attrs
106113
SetAttr(desc.Proto(), "subgraph", block.SerializeAsString());
107-
SetAttr(desc.Proto(), "engine_unique_key",
108-
"trt-" + std::to_string(counter++));
109-
SetAttr(desc.Proto(), "max_batch", 100); // TODO(Superjomn) add config latter
110-
SetAttr(desc.Proto(), "max_workspace",
111-
1024); // TODO(Superjomn) add config latter
114+
SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++));
115+
SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize);
116+
SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size);
112117
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
113118
node->SetPbMsg(desc.Proto()->SerializeAsString());
114119
}
115120

116121
std::vector<std::string> ExtractParameters(
117-
const std::vector<std::unique_ptr<Node>>& nodes) {
122+
const std::vector<std::unique_ptr<Node>> &nodes) {
118123
std::vector<std::string> parameters;
119-
for (const auto& node : nodes) {
124+
for (const auto &node : nodes) {
120125
if (!node->IsValue()) continue;
121126
PADDLE_ENFORCE(!node->pb_msg().empty(), "pb_msg should be set first");
122127
framework::proto::VarDesc var;
@@ -128,21 +133,30 @@ std::vector<std::string> ExtractParameters(
128133
return parameters;
129134
}
130135

131-
void DataFlowGraphToFluidPass::AddEngineOp(Node* node) {
136+
void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
132137
// TODO(Superjomn) Here need to expose some arguments for default setting.
133138
PADDLE_ENFORCE(node->IsFunctionBlock());
134-
auto* block_node = static_cast<FunctionBlock*>(node);
139+
auto *block_node = static_cast<FunctionBlock *>(node);
135140
framework::proto::BlockDesc proto;
136141
framework::BlockDesc block_desc(nullptr, &proto);
142+
block_desc.Proto()->set_parent_idx(-1);
143+
block_desc.Proto()->set_idx(0);
144+
LOG(INFO) << "origin variable size: "
145+
<< argument_->origin_program_desc->blocks(0).vars().size();
146+
LOG(INFO) << "transformed variable size: "
147+
<< block_desc.Proto()->vars().size();
137148
// copy ops.
138-
for (auto* node : block_node->subgraph) {
139-
auto* op = block_desc.AppendOp();
149+
for (auto *node : block_node->subgraph) {
150+
auto *op = block_desc.AppendOp();
140151
PADDLE_ENFORCE(!node->pb_msg().empty());
141152
op->Proto()->ParseFromString(node->pb_msg());
142153
}
154+
*block_desc.Proto()->mutable_vars() =
155+
argument_->origin_program_desc->blocks(0).vars();
156+
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
143157
CreateTrtEngineOp(node, *argument_->main_dfg, *block_desc.Proto());
144-
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
145-
auto* op = main_block->add_ops();
158+
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
159+
auto *op = main_block->add_ops();
146160
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");
147161
op->ParseFromString(node->pb_msg());
148162
}
@@ -151,7 +165,7 @@ namespace {
151165
class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
152166
public:
153167
using Config = DFG_GraphvizDrawPass::Config;
154-
explicit DFG_DebuggerPass(const Config& config)
168+
explicit DFG_DebuggerPass(const Config &config)
155169
: DFG_GraphvizDrawPass(config) {}
156170

157171
std::string repr() const override { return "dfg-to-fluid-debuger-pass"; }
@@ -160,7 +174,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
160174
};
161175
} // namespace
162176

163-
Pass* DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
177+
Pass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
164178
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
165179
FLAGS_inference_analysis_graphviz_log_root,
166180
"data_flow_graph_to_fluid_graphviz_debugger"));

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626

2727
namespace paddle {
2828
namespace inference {
29+
30+
DECLARE_int32(tensorrt_max_batchsize);
31+
DECLARE_int32(tensorrt_workspace_size);
32+
2933
namespace analysis {
3034
class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
3135
public:

paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
4040
no++;
4141
}
4242
// DFG is sensitive to ProgramDesc, be careful to change the existing models.
43-
ASSERT_EQ(no, 82);
43+
ASSERT_EQ(no, 83);
4444
}
4545

4646
} // namespace analysis

0 commit comments

Comments
 (0)