Skip to content

Commit c69ae86

Browse files
committed
fix comments
1 parent 8f9e704 commit c69ae86

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
8787
}
8888

8989
void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
90-
framework::proto::BlockDesc &block) {
90+
framework::proto::BlockDesc *block) {
9191
static int counter{0};
9292
PADDLE_ENFORCE(node->IsFunctionBlock());
9393
framework::OpDesc desc;
@@ -112,22 +112,33 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
112112
desc.SetType("tensorrt_engine");
113113

114114
std::unordered_map<std::string, std::string> output_name_map;
115-
auto subgraph_nodes = func->subgraph;
116115

117-
for (int index = 0; index < block.ops_size(); index++) {
118-
framework::proto::OpDesc *op = block.mutable_ops(index);
119-
// auto &op = block.mutable_ops(index);
116+
// The following procedure is used to rename all the intermediate
117+
// variables and the output variables of the subgraph.
118+
// Why do we do this?
119+
// During the transition from fluid OP to tensorrt OP, we map
120+
// the input and output Tensor(fluid data structure) of fluid OP
121+
// to the corresponding ITensor (trt data structure) through the
122+
// Tensor name. When we set up ITensor for a variable, we must
123+
// ensure that it has not been set before.
124+
// If there is a variable in the fluid graph, which is not only the
125+
// input of an OP, but also the output of an OP, there will be problems.
126+
// So we have to rename the variable in the subgraph to make sure
127+
// it is either an OP's input or an OP's output.
128+
129+
auto subgraph_nodes = func->subgraph;
130+
for (int index = 0; index < block->ops_size(); index++) {
131+
framework::proto::OpDesc *op = block->mutable_ops(index);
120132
auto correspond_node = subgraph_nodes[index];
121133
PADDLE_ENFORCE_EQ(correspond_node->name(), op->type());
122134

123135
std::unordered_map<std::string, size_t> var2id;
124136
for (auto *in_var : correspond_node->inlinks) {
125137
var2id[in_var->name()] = in_var->id();
126138
}
127-
// TODO(zhaolong): add comments
139+
// rename for the input variables of op inside subgraph
128140
for (int i = 0; i < op->inputs_size(); i++) {
129141
framework::proto::OpDesc_Var *in_var = op->mutable_inputs(i);
130-
// auto &in_var = op->mutable_inputs(i);
131142
std::vector<std::string> replaced_names;
132143
for (int k = 0; k < in_var->arguments_size(); k++) {
133144
std::string arg_value = in_var->arguments(k);
@@ -148,6 +159,7 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
148159
var2id[out_var->name()] = out_var->id();
149160
}
150161

162+
// rename for the output variables of op inside subgraph
151163
for (int i = 0; i < op->outputs_size(); i++) {
152164
framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i);
153165
std::vector<std::string> replaced_names;
@@ -165,15 +177,18 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
165177
}
166178
}
167179
}
180+
// When tensorrt engine runs at the end of the operation,
181+
// output_mapping helps us copy the data from the renamed ITensor
182+
// to Tensor.
168183
std::vector<std::string> output_mapping;
169184
for (auto name : output_names) {
170185
PADDLE_ENFORCE(output_name_map.count(name) != 0);
171186
output_mapping.push_back(output_name_map[name]);
172187
}
173188

174-
PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc");
189+
PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc");
175190
// Set attrs
176-
SetAttr(desc.Proto(), "subgraph", block.SerializeAsString());
191+
SetAttr(desc.Proto(), "subgraph", block->SerializeAsString());
177192
SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++));
178193
SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize);
179194
SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size);
@@ -220,7 +235,7 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
220235
*block_desc.Proto()->mutable_vars() =
221236
argument_->origin_program_desc->blocks(0).vars();
222237
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
223-
CreateTrtEngineOp(node, *argument_->main_dfg, *block_desc.Proto());
238+
CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto());
224239
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
225240
auto *op = main_block->add_ops();
226241
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");

0 commit comments

Comments
 (0)