@@ -87,7 +87,7 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
87
87
}
88
88
89
89
void CreateTrtEngineOp (Node *node, const DataFlowGraph &graph,
90
- framework::proto::BlockDesc & block) {
90
+ framework::proto::BlockDesc * block) {
91
91
static int counter{0 };
92
92
PADDLE_ENFORCE (node->IsFunctionBlock ());
93
93
framework::OpDesc desc;
@@ -112,22 +112,33 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
112
112
desc.SetType (" tensorrt_engine" );
113
113
114
114
std::unordered_map<std::string, std::string> output_name_map;
115
- auto subgraph_nodes = func->subgraph ;
116
115
117
- for (int index = 0 ; index < block.ops_size (); index++) {
118
- framework::proto::OpDesc *op = block.mutable_ops (index);
119
- // auto &op = block.mutable_ops(index);
116
+ // The following procedure is used to rename all the intermediate
117
+ // variables and the output variables of the subgraph.
118
+ // Why do we do this?
119
+ // During the transition from fluid OP to tensorrt OP, we map
120
+ // the input and output Tensor(fluid data structure) of fluid OP
121
+ // to the correspondin ITensor (trt data structure) through the
122
+ // Tensor name. When we set up ITensor for a variable, we must
123
+ // ensure that it has not been set before.
124
+ // If there is a variable in the fluid graph, which is not only the
125
+ // input of an OP, but also the output of an OP, there will be problems.
126
+ // So we have to rename the variable in the subgraph to make sure
127
+ // it is either an OP's input or an OP's output.
128
+
129
+ auto subgraph_nodes = func->subgraph ;
130
+ for (int index = 0 ; index < block->ops_size (); index++) {
131
+ framework::proto::OpDesc *op = block->mutable_ops (index);
120
132
auto correspond_node = subgraph_nodes[index];
121
133
PADDLE_ENFORCE_EQ (correspond_node->name (), op->type ());
122
134
123
135
std::unordered_map<std::string, size_t > var2id;
124
136
for (auto *in_var : correspond_node->inlinks ) {
125
137
var2id[in_var->name ()] = in_var->id ();
126
138
}
127
- // TODO(zhaolong): add comments
139
+ // rename for the input variables of op inside subgraph
128
140
for (int i = 0 ; i < op->inputs_size (); i++) {
129
141
framework::proto::OpDesc_Var *in_var = op->mutable_inputs (i);
130
- // auto &in_var = op->mutable_inputs(i);
131
142
std::vector<std::string> replaced_names;
132
143
for (int k = 0 ; k < in_var->arguments_size (); k++) {
133
144
std::string arg_value = in_var->arguments (k);
@@ -148,6 +159,7 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
148
159
var2id[out_var->name ()] = out_var->id ();
149
160
}
150
161
162
+ // rename for the output variables of op inside subgraph
151
163
for (int i = 0 ; i < op->outputs_size (); i++) {
152
164
framework::proto::OpDesc_Var *out_var = op->mutable_outputs (i);
153
165
std::vector<std::string> replaced_names;
@@ -165,15 +177,18 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
165
177
}
166
178
}
167
179
}
180
+ // When tensorrt engine runs at the end of the operation,
181
+ // output_mapping helps us copy the data from the renamed ITensor
182
+ // to Tensor.
168
183
std::vector<std::string> output_mapping;
169
184
for (auto name : output_names) {
170
185
PADDLE_ENFORCE (output_name_map.count (name) != 0 );
171
186
output_mapping.push_back (output_name_map[name]);
172
187
}
173
188
174
- PADDLE_ENFORCE (!block. vars ().empty (), " the block has no var-desc" );
189
+ PADDLE_ENFORCE (!block-> vars ().empty (), " the block has no var-desc" );
175
190
// Set attrs
176
- SetAttr (desc.Proto (), " subgraph" , block. SerializeAsString ());
191
+ SetAttr (desc.Proto (), " subgraph" , block-> SerializeAsString ());
177
192
SetAttr (desc.Proto (), " engine_uniq_key" , " trt-" + std::to_string (counter++));
178
193
SetAttr (desc.Proto (), " max_batch" , FLAGS_tensorrt_max_batchsize);
179
194
SetAttr (desc.Proto (), " max_workspace" , FLAGS_tensorrt_workspace_size);
@@ -220,7 +235,7 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
220
235
*block_desc.Proto ()->mutable_vars () =
221
236
argument_->origin_program_desc ->blocks (0 ).vars ();
222
237
PADDLE_ENFORCE (!block_desc.Proto ()->vars ().empty ());
223
- CreateTrtEngineOp (node, *argument_->main_dfg , * block_desc.Proto ());
238
+ CreateTrtEngineOp (node, *argument_->main_dfg , block_desc.Proto ());
224
239
auto *main_block = desc_->mutable_blocks (framework::kRootBlockIndex );
225
240
auto *op = main_block->add_ops ();
226
241
PADDLE_ENFORCE (!node->pb_msg ().empty (), " failed to set desc for block" );
0 commit comments