22
22
23
23
namespace paddle {
24
24
namespace inference {
25
+
26
+ DEFINE_int32 (tensorrt_max_batchsize, 300 , " TensorRT maximum batch size" );
27
+ DEFINE_int32 (tensorrt_workspace_size, 2048 , " TensorRT workspace size" );
28
+
25
29
namespace analysis {
26
30
27
31
using framework::proto::ProgramDesc;
28
32
29
33
std::vector<std::string> ExtractParameters (
30
- const std::vector<std::unique_ptr<Node>>& nodes);
34
+ const std::vector<std::unique_ptr<Node>> & nodes);
31
35
32
- bool DataFlowGraphToFluidPass::Initialize (Argument* argument) {
36
+ bool DataFlowGraphToFluidPass::Initialize (Argument * argument) {
33
37
ANALYSIS_ARGUMENT_CHECK_FIELD (argument)
34
38
ANALYSIS_ARGUMENT_CHECK_FIELD (argument->origin_program_desc )
35
39
PADDLE_ENFORCE (!argument->transformed_program_desc );
@@ -47,76 +51,77 @@ bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
47
51
48
52
bool DataFlowGraphToFluidPass::Finalize () { return true ; }
49
53
50
- void DataFlowGraphToFluidPass::Run (DataFlowGraph* graph) {
51
- auto traits = GraphTraits<DataFlowGraph>(graph );
52
- for (auto it = traits. nodes (). begin (); it != traits. nodes (). end (); ++it ) {
53
- if (it-> deleted ()) continue ;
54
+ void DataFlowGraphToFluidPass::Run (DataFlowGraph * graph) {
55
+ LOG (INFO) << " graph.inputs " << graph-> inputs . size ( );
56
+ for (auto &node : GraphTraits<DataFlowGraph>(graph). nodes_in_TS () ) {
57
+ if (node. deleted ()) continue ;
54
58
55
- switch (it-> type ()) {
59
+ switch (node. type ()) {
56
60
case Node::Type::kFunction : {
57
- LOG (INFO) << " add function " << it-> repr ();
58
- AddFluidOp (&(*it) );
61
+ LOG (INFO) << " add function " << node. repr ();
62
+ AddFluidOp (&node );
59
63
} break ;
60
64
case Node::Type::kFunctionBlock : {
61
- LOG (INFO) << " add engine op " << it-> repr () << " , "
62
- << static_cast <FunctionBlock*>(&(*it) )->subgraph .size ();
63
- AddEngineOp (&(*it) );
65
+ LOG (INFO) << " add engine op " << node. repr () << " , "
66
+ << static_cast <FunctionBlock *>(&node )->subgraph .size ();
67
+ AddEngineOp (&node );
64
68
} break ;
65
69
default :
66
70
continue ;
67
71
}
68
72
}
73
+
74
+ PADDLE_ENFORCE (argument_->transformed_program_desc .get ());
69
75
}
70
76
71
- void DataFlowGraphToFluidPass::AddFluidOp (Node* node) {
72
- auto * ori_op = static_cast <framework::proto::OpDesc*>(node->pb_desc ());
77
+ void DataFlowGraphToFluidPass::AddFluidOp (Node * node) {
78
+ auto * ori_op = static_cast <framework::proto::OpDesc *>(node->pb_desc ());
73
79
// currently only the main block is analyzed.
74
- auto * main_block = desc_->mutable_blocks (framework::kRootBlockIndex );
75
- auto * op = main_block->add_ops ();
80
+ auto * main_block = desc_->mutable_blocks (framework::kRootBlockIndex );
81
+ auto * op = main_block->add_ops ();
76
82
*op = *ori_op; // copy the attributes, by default, these will not be changed
77
- // by analysis phrase.
83
+ // by analysis phrase.
78
84
// The inputs and outputs of the existing ops are not changed by tensorrt
79
85
// subgraph pass.
80
86
// NOTE It might be changed by other passes in the long run.
81
87
}
82
88
83
- void CreateTrtEngineOp (Node* node, const DataFlowGraph& graph,
84
- const framework::proto::BlockDesc& block) {
89
+ void CreateTrtEngineOp (Node * node, const DataFlowGraph & graph,
90
+ const framework::proto::BlockDesc & block) {
85
91
static int counter{0 };
86
92
PADDLE_ENFORCE (node->IsFunctionBlock ());
87
93
framework::OpDesc desc;
88
- auto * func = static_cast <FunctionBlock*>(node);
94
+ auto * func = static_cast <FunctionBlock *>(node);
89
95
90
96
// collect inputs
91
97
std::vector<std::string> io;
92
- for (auto * x : func->inlinks ) {
98
+ for (auto * x : func->inlinks ) {
93
99
io.push_back (x->name ());
94
100
}
95
101
desc.SetInput (" Xs" , io);
96
102
97
103
// collect outputs
98
104
io.clear ();
99
- for (auto * x : func->outlinks ) {
105
+ for (auto * x : func->outlinks ) {
100
106
io.push_back (x->name ());
101
107
}
102
108
desc.SetOutput (" Ys" , io);
103
-
104
109
desc.SetType (" tensorrt_engine" );
110
+
111
+ PADDLE_ENFORCE (!block.vars ().empty (), " the block has no var-desc" );
105
112
// Set attrs
106
113
SetAttr (desc.Proto (), " subgraph" , block.SerializeAsString ());
107
- SetAttr (desc.Proto (), " engine_unique_key" ,
108
- " trt-" + std::to_string (counter++));
109
- SetAttr (desc.Proto (), " max_batch" , 100 ); // TODO(Superjomn) add config latter
110
- SetAttr (desc.Proto (), " max_workspace" ,
111
- 1024 ); // TODO(Superjomn) add config latter
114
+ SetAttr (desc.Proto (), " engine_uniq_key" , " trt-" + std::to_string (counter++));
115
+ SetAttr (desc.Proto (), " max_batch" , FLAGS_tensorrt_max_batchsize);
116
+ SetAttr (desc.Proto (), " max_workspace" , FLAGS_tensorrt_workspace_size);
112
117
SetAttr (desc.Proto (), " parameters" , ExtractParameters (graph.nodes .nodes ()));
113
118
node->SetPbMsg (desc.Proto ()->SerializeAsString ());
114
119
}
115
120
116
121
std::vector<std::string> ExtractParameters (
117
- const std::vector<std::unique_ptr<Node>>& nodes) {
122
+ const std::vector<std::unique_ptr<Node>> & nodes) {
118
123
std::vector<std::string> parameters;
119
- for (const auto & node : nodes) {
124
+ for (const auto & node : nodes) {
120
125
if (!node->IsValue ()) continue ;
121
126
PADDLE_ENFORCE (!node->pb_msg ().empty (), " pb_msg should be set first" );
122
127
framework::proto::VarDesc var;
@@ -128,21 +133,30 @@ std::vector<std::string> ExtractParameters(
128
133
return parameters;
129
134
}
130
135
131
- void DataFlowGraphToFluidPass::AddEngineOp (Node* node) {
136
+ void DataFlowGraphToFluidPass::AddEngineOp (Node * node) {
132
137
// TODO(Superjomn) Here need to expose some arguments for default setting.
133
138
PADDLE_ENFORCE (node->IsFunctionBlock ());
134
- auto * block_node = static_cast <FunctionBlock*>(node);
139
+ auto * block_node = static_cast <FunctionBlock *>(node);
135
140
framework::proto::BlockDesc proto;
136
141
framework::BlockDesc block_desc (nullptr , &proto);
142
+ block_desc.Proto ()->set_parent_idx (-1 );
143
+ block_desc.Proto ()->set_idx (0 );
144
+ LOG (INFO) << " origin variable size: "
145
+ << argument_->origin_program_desc ->blocks (0 ).vars ().size ();
146
+ LOG (INFO) << " transformed variable size: "
147
+ << block_desc.Proto ()->vars ().size ();
137
148
// copy ops.
138
- for (auto * node : block_node->subgraph ) {
139
- auto * op = block_desc.AppendOp ();
149
+ for (auto * node : block_node->subgraph ) {
150
+ auto * op = block_desc.AppendOp ();
140
151
PADDLE_ENFORCE (!node->pb_msg ().empty ());
141
152
op->Proto ()->ParseFromString (node->pb_msg ());
142
153
}
154
+ *block_desc.Proto ()->mutable_vars () =
155
+ argument_->origin_program_desc ->blocks (0 ).vars ();
156
+ PADDLE_ENFORCE (!block_desc.Proto ()->vars ().empty ());
143
157
CreateTrtEngineOp (node, *argument_->main_dfg , *block_desc.Proto ());
144
- auto * main_block = desc_->mutable_blocks (framework::kRootBlockIndex );
145
- auto * op = main_block->add_ops ();
158
+ auto * main_block = desc_->mutable_blocks (framework::kRootBlockIndex );
159
+ auto * op = main_block->add_ops ();
146
160
PADDLE_ENFORCE (!node->pb_msg ().empty (), " failed to set desc for block" );
147
161
op->ParseFromString (node->pb_msg ());
148
162
}
@@ -151,7 +165,7 @@ namespace {
151
165
class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
152
166
public:
153
167
using Config = DFG_GraphvizDrawPass::Config;
154
- explicit DFG_DebuggerPass (const Config& config)
168
+ explicit DFG_DebuggerPass (const Config & config)
155
169
: DFG_GraphvizDrawPass(config) {}
156
170
157
171
std::string repr () const override { return " dfg-to-fluid-debuger-pass" ; }
@@ -160,7 +174,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
160
174
};
161
175
} // namespace
162
176
163
- Pass* DataFlowGraphToFluidPass::CreateGraphvizDebugerPass () const {
177
+ Pass * DataFlowGraphToFluidPass::CreateGraphvizDebugerPass () const {
164
178
return new DFG_DebuggerPass (DFG_GraphvizDrawPass::Config (
165
179
FLAGS_inference_analysis_graphviz_log_root,
166
180
" data_flow_graph_to_fluid_graphviz_debugger" ));
0 commit comments