Skip to content

Commit 7a019cd

Browse files
committed
merge develop
2 parents e823ce6 + 46fe9ba commit 7a019cd

File tree

8 files changed

+137
-32
lines changed

8 files changed

+137
-32
lines changed

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
namespace paddle {
2626

27-
DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
27+
DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, true,
2828
"Enable subgraph to TensorRT engine for acceleration");
2929

3030
DEFINE_string(inference_analysis_graphviz_log_root, "./",
@@ -42,10 +42,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
4242
// TODO(Superjomn) set the key with pass reprs.
4343
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
4444
if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) {
45-
auto trt_teller = [](const Node* node) {
45+
auto trt_teller = [&](const Node* node) {
46+
std::unordered_set<std::string> teller_set(
47+
{"elementwise_add", "mul", "conv2d", "pool2d", "relu"});
4648
if (!node->IsFunction()) return false;
47-
return static_cast<const Function*>(node)->func_type() == "mul";
49+
50+
const auto* func = static_cast<const Function*>(node);
51+
if (teller_set.count(func->func_type()))
52+
return true;
53+
else {
54+
return false;
55+
}
4856
};
57+
4958
AddPass("tensorrt-subgraph-marker",
5059
new TensorRTSubgraphNodeMarkPass(trt_teller));
5160
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc

Lines changed: 93 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
namespace paddle {
2424
namespace inference {
2525

26-
DEFINE_int32(tensorrt_max_batchsize, 300, "TensorRT maximum batch size");
26+
DEFINE_int32(tensorrt_max_batchsize, 3, "TensorRT maximum batch size");
2727
DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size");
2828

2929
namespace analysis {
@@ -88,34 +88,113 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
8888
}
8989

9090
void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
91-
const framework::proto::BlockDesc &block) {
91+
framework::proto::BlockDesc *block) {
9292
static int counter{0};
9393
PADDLE_ENFORCE(node->IsFunctionBlock());
9494
framework::OpDesc desc;
9595
auto *func = static_cast<FunctionBlock *>(node);
9696

9797
// collect inputs
98-
std::vector<std::string> io;
98+
std::unordered_set<std::string> input_names;
9999
for (auto *x : func->inlinks) {
100-
io.push_back(x->name());
100+
input_names.insert(x->name());
101101
}
102-
desc.SetInput("Xs", io);
102+
desc.SetInput(
103+
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
103104

104-
// collect outputs
105-
io.clear();
105+
std::unordered_set<std::string> output_names;
106106
for (auto *x : func->outlinks) {
107-
io.push_back(x->name());
107+
output_names.insert(x->name());
108108
}
109-
desc.SetOutput("Ys", io);
109+
110+
std::vector<std::string> output_temp(output_names.begin(),
111+
output_names.end());
112+
desc.SetOutput("Ys", output_temp);
110113
desc.SetType("tensorrt_engine");
111114

112-
PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc");
115+
std::unordered_map<std::string, std::string> output_name_map;
116+
117+
// The following procedure is used to rename all the intermediate
118+
// variables and the output variables of the subgraph.
119+
// Why we do this?
120+
// During the transition from fluid OP to tensorrt OP, we map
121+
// the input and output Tensor(fluid data structure) of fluid OP
122+
// to the correspondin ITensor (trt data structure) through the
123+
// Tensor name. When we set up ITensor for an variable, we must
124+
// ensure that it has not been set before.
125+
// If there is variable in the fluid graph, which is not only the
126+
// input of a OP, but also the output of a Op, there will be problems.
127+
// So we have to rename the variable in the subgraph to make sure
128+
// it is either an OP's input or an OP's output.
129+
130+
auto subgraph_nodes = func->subgraph;
131+
for (int index = 0; index < block->ops_size(); index++) {
132+
framework::proto::OpDesc *op = block->mutable_ops(index);
133+
auto correspond_node = subgraph_nodes[index];
134+
PADDLE_ENFORCE_EQ(correspond_node->name(), op->type());
135+
136+
std::unordered_map<std::string, size_t> var2id;
137+
for (auto *in_var : correspond_node->inlinks) {
138+
var2id[in_var->name()] = in_var->id();
139+
}
140+
// rename for the input variables of op inside subgraph
141+
for (int i = 0; i < op->inputs_size(); i++) {
142+
framework::proto::OpDesc_Var *in_var = op->mutable_inputs(i);
143+
std::vector<std::string> replaced_names;
144+
for (int k = 0; k < in_var->arguments_size(); k++) {
145+
std::string arg_value = in_var->arguments(k);
146+
if (input_names.count(arg_value)) {
147+
replaced_names.push_back(arg_value);
148+
} else {
149+
replaced_names.push_back(arg_value +
150+
std::to_string(var2id[arg_value]));
151+
}
152+
}
153+
in_var->clear_arguments();
154+
for (size_t k = 0; k < replaced_names.size(); k++) {
155+
in_var->add_arguments(replaced_names[k]);
156+
}
157+
}
158+
var2id.clear();
159+
for (auto out_var : correspond_node->outlinks) {
160+
var2id[out_var->name()] = out_var->id();
161+
}
162+
163+
// rename for the output variables of op inside subgraph
164+
for (int i = 0; i < op->outputs_size(); i++) {
165+
framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i);
166+
std::vector<std::string> replaced_names;
167+
for (int k = 0; k < out_var->arguments_size(); k++) {
168+
std::string arg_value = out_var->arguments(k);
169+
if (output_names.count(arg_value)) {
170+
output_name_map[arg_value] =
171+
arg_value + std::to_string(var2id[arg_value]);
172+
}
173+
replaced_names.push_back(arg_value + std::to_string(var2id[arg_value]));
174+
}
175+
out_var->clear_arguments();
176+
for (size_t k = 0; k < replaced_names.size(); k++) {
177+
out_var->add_arguments(replaced_names[k]);
178+
}
179+
}
180+
}
181+
// When tensorrt engine runs at the end of the operation,
182+
// output_mapping help us copy the data from the renamed ITensor
183+
// to Tensor.
184+
std::vector<std::string> output_mapping;
185+
for (auto name : output_names) {
186+
PADDLE_ENFORCE(output_name_map.count(name) != 0);
187+
output_mapping.push_back(output_name_map[name]);
188+
}
189+
190+
PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc");
113191
// Set attrs
114-
SetAttr(desc.Proto(), "subgraph", block.SerializeAsString());
192+
SetAttr(desc.Proto(), "subgraph", block->SerializeAsString());
115193
SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++));
116194
SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize);
117195
SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size);
118196
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
197+
SetAttr(desc.Proto(), "output_name_mapping", output_mapping);
119198
node->SetPbMsg(desc.Proto()->SerializeAsString());
120199
}
121200

@@ -147,15 +226,17 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
147226
LOG(INFO) << "transformed variable size: "
148227
<< block_desc.Proto()->vars().size();
149228
// copy ops.
229+
150230
for (auto *node : block_node->subgraph) {
151231
auto *op = block_desc.AppendOp();
152232
PADDLE_ENFORCE(!node->pb_msg().empty());
153233
op->Proto()->ParseFromString(node->pb_msg());
154234
}
235+
155236
*block_desc.Proto()->mutable_vars() =
156237
argument_->origin_program_desc->blocks(0).vars();
157238
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
158-
CreateTrtEngineOp(node, *argument_->main_dfg, *block_desc.Proto());
239+
CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto());
159240
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
160241
auto *op = main_block->add_ops();
161242
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");

paddle/fluid/inference/analysis/subgraph_splitter.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
7676

7777
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
7878
std::vector<Node *> marked_nodes;
79-
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes()) {
79+
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes_in_TS()) {
8080
if (node.attr(kMarkerAttrName).Bool()) {
8181
marked_nodes.push_back(&node);
8282
}

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Add TRT tests
22
nv_library(tensorrt_converter
33
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
4+
activation_op.cc
45
DEPS tensorrt_engine operator scope framework_proto op_registry)
56

67
nv_test(test_op_converter SRCS test_op_converter.cc DEPS

paddle/fluid/inference/tensorrt/convert/op_converter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ class OpConverter {
5555
it = Registry<OpConverter>::Lookup("fc");
5656
}
5757
}
58-
5958
if (op_desc.Type().find("elementwise") != std::string::npos) {
6059
static std::unordered_set<std::string> add_tensor_op_set{
6160
"add", "mul", "sub", "div", "max", "min", "pow"};
@@ -72,6 +71,8 @@ class OpConverter {
7271
"Unsupported elementwise type" + op_type);
7372
it =
7473
Registry<OpConverter>::Lookup("elementwise_" + op_type + "_weight");
74+
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
75+
op_desc.Type());
7576
} else {
7677
PADDLE_ENFORCE(add_tensor_op_set.count(op_type) > 0,
7778
"Unsupported elementwise type" + op_type);

paddle/fluid/operators/tensorrt_engine_op.cc

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,8 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
5555
"TensorRT' tensor input requires at least 2 dimensions");
5656
PADDLE_ENFORCE_LE(shape.size(), 4UL,
5757
"TensorRT' tensor input requires at most 4 dimensions");
58-
59-
switch (shape.size()) {
60-
case 2:
61-
return nvinfer1::Dims2(1, shape[1]);
62-
case 3:
63-
return nvinfer1::Dims3(1, shape[1], shape[2]);
64-
case 4:
65-
return nvinfer1::Dims4(1, shape[1], shape[2], shape[3]);
66-
default:
67-
return nvinfer1::Dims();
68-
}
69-
return nvinfer1::Dims();
58+
PADDLE_ENFORCE_EQ(shape.size(), 4UL);
59+
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
7060
}
7161

7262
} // namespace
@@ -86,6 +76,9 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare(
8676
parameters.insert(param);
8777
}
8878

79+
std::vector<std::string> output_maps =
80+
context.Attr<std::vector<std::string>>("output_name_mapping");
81+
8982
// TODO(Superjomn) replace this with a different stream
9083
auto *engine = Singleton<TRT_EngineManager>::Global().Create(
9184
max_batch, max_workspace, nullptr /*engine hold its own stream*/,
@@ -97,6 +90,7 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare(
9790
// Add inputs
9891
VLOG(4) << "declare inputs";
9992
for (auto &input : context.Inputs("Xs")) {
93+
if (parameters.count(input)) continue;
10094
VLOG(4) << "declare input " << input;
10195
auto *var = block.FindVar(input);
10296
// TensorRT engine need to create parameters. The parameter's description
@@ -122,7 +116,7 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare(
122116
block_desc, parameters, context.scope(), engine);
123117

124118
// Add outputs
125-
for (auto &output : context.Outputs("Ys")) {
119+
for (auto &output : output_maps) {
126120
engine->DeclareOutput(output);
127121
}
128122

paddle/fluid/operators/tensorrt_engine_op.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,17 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
6666
PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size,
6767
context.Attr<int>("max_batch"));
6868

69+
std::vector<std::string> output_maps =
70+
context.Attr<std::vector<std::string>>("output_name_mapping");
71+
72+
auto params = context.Attr<std::vector<std::string>>("parameters");
73+
std::unordered_set<std::string> parameters;
74+
for (const auto& param : params) {
75+
parameters.insert(param);
76+
}
6977
// Convert input tensor from fluid to engine.
7078
for (const auto& x : context.Inputs("Xs")) {
79+
if (parameters.count(x)) continue;
7180
// convert input and copy to TRT engine's buffer
7281
auto& t = inference::analysis::GetFromScope<framework::LoDTensor>(
7382
context.scope(), x);
@@ -82,10 +91,12 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
8291
// Execute the engine.
8392
PADDLE_ENFORCE_GT(FLAGS_tensorrt_engine_batch_size, 0);
8493
engine->Execute(FLAGS_tensorrt_engine_batch_size);
94+
8595
// Convert output tensor from engine to fluid
96+
int output_index = 0;
8697
for (const auto& y : context.Outputs("Ys")) {
8798
// convert output and copy to fluid.
88-
nvinfer1::ITensor* trt_t = engine->GetITensor(y);
99+
nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]);
89100
auto dims = trt_t->getDimensions();
90101
// Use the output ITensor's dims to reshape the Fluid Tensor.
91102
std::vector<int> ddim(dims.d, dims.d + dims.nbDims);
@@ -102,14 +113,15 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
102113
// TODO(Superjomn) change this float to dtype size.
103114
auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) *
104115
FLAGS_tensorrt_engine_batch_size;
105-
engine->GetOutputInCPU(y,
116+
engine->GetOutputInCPU(output_maps[output_index],
106117
fluid_t->mutable_data<float>(platform::CPUPlace()),
107118
size * sizeof(float));
108119
//} else {
109120
// engine->GetOutputInGPU(
110121
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
111122
// size * sizeof(float));
112123
//}
124+
output_index += 1;
113125
}
114126

115127
cudaStreamSynchronize(*engine->stream());

paddle/fluid/operators/tensorrt_engine_op_test.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ TEST(TensorRTEngineOp, manual) {
103103
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
104104
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
105105
std::vector<std::string>({}));
106+
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(),
107+
"output_name_mapping",
108+
std::vector<std::string>({"z0"}));
106109

107110
LOG(INFO) << "create engine op";
108111
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
@@ -196,6 +199,10 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
196199
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
197200
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "b_engine");
198201

202+
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(),
203+
"output_name_mapping",
204+
std::vector<std::string>({"z3"}));
205+
199206
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
200207

201208
// Execute them.

0 commit comments

Comments
 (0)