Skip to content

Commit 8bc1c5d

Browse files
authored
Implement the TensorRT plugin for elementwise op (#14487)
* Initialize the elementwise plugin. * Implement the basic CUDA kernel of elementwise plugin. test=develop
1 parent 7aa3aff commit 8bc1c5d

23 files changed

+500
-166
lines changed

paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
114114
// it is either an OP's input or an OP's output.
115115

116116
auto &subgraph_nodes = *Agent(node).subgraph();
117-
for (size_t index = 0; index < block_desc.OpSize(); index++) {
117+
for (size_t index = 0; index < block_desc.OpSize(); ++index) {
118118
framework::proto::OpDesc *op = block_desc.Op(index)->Proto();
119119
auto correspond_node = subgraph_nodes[index];
120120
PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type());

paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
4545
std::unordered_set<std::string> teller_set(
4646
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
4747
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
48-
"elementwise_add", "dropout", "split", "prelu", "conv2d_transpose"});
48+
"elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
49+
"conv2d_transpose"});
4950
if (!node->IsOp()) return false;
5051

5152
if (teller_set.count(node->Op()->Type())) {

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Add TRT tests
22
nv_library(tensorrt_converter
3-
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
4-
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
5-
pad_op.cc split_op.cc prelu_op.cc
6-
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
3+
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
4+
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
5+
pad_op.cc split_op.cc prelu_op.cc
6+
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
77

88
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
99
${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter)
@@ -20,7 +20,8 @@ nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
2020
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
2121
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pool_op SERIAL)
2222
nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
23-
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine elementwise_add_op SERIAL)
23+
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
24+
elementwise_add_op elementwise_mul_op SERIAL)
2425
nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
2526
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine softmax_op SERIAL)
2627
nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
@@ -33,7 +34,7 @@ nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
3334
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pad_op SERIAL)
3435
nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
3536
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
36-
split_op concat_op SERIAL)
37+
split_op concat_op SERIAL)
3738
nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc
3839
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
3940
prelu_op SERIAL)

paddle/fluid/inference/tensorrt/convert/elementwise_op.cc

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
55
You may obtain a copy of the License at
66
7-
http://www.apache.org/licenses/LICENSE-2.0
7+
http://www.apache.org/licenses/LICENSE-2.0
88
99
Unless required by applicable law or agreed to in writing, software
1010
distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,11 +13,25 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
16+
#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"
1617

1718
namespace paddle {
1819
namespace inference {
1920
namespace tensorrt {
2021

22+
static bool CheckDims(const nvinfer1::Dims& dims_x,
23+
const nvinfer1::Dims& dims_y) {
24+
if (dims_x.nbDims != dims_y.nbDims) {
25+
return false;
26+
}
27+
for (int i = 0; i < dims_x.nbDims; i++) {
28+
if (dims_x.d[i] != dims_y.d[i]) {
29+
return false;
30+
}
31+
}
32+
return true;
33+
}
34+
2135
class ElementwiseWeightOpConverter : public OpConverter {
2236
public:
2337
ElementwiseWeightOpConverter() {}
@@ -26,7 +40,7 @@ class ElementwiseWeightOpConverter : public OpConverter {
2640
// Here the two nullptr looks strange, that's because the
2741
// framework::OpDesc's constructor is strange.
2842
framework::OpDesc op_desc(op, nullptr);
29-
VLOG(3) << "convert a fluid elementwise op to tensorrt IScaleLayer";
43+
VLOG(3) << "Convert a fluid elementwise op to TensorRT IScaleLayer";
3044

3145
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
3246
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); // Y is a weight
@@ -106,10 +120,12 @@ class ElementwiseTensorOpConverter : public OpConverter {
106120
ElementwiseTensorOpConverter() {}
107121
void operator()(const framework::proto::OpDesc& op,
108122
const framework::Scope& scope, bool test_mode) override {
123+
auto op_pair = ops.find(op_type_);
124+
PADDLE_ENFORCE(op_pair != ops.end(), "Wrong elementwise op type!");
125+
109126
// Here the two nullptr looks strange, that's because the
110127
// framework::OpDesc's constructor is strange.
111128
framework::OpDesc op_desc(op, nullptr);
112-
VLOG(3) << "convert a fluid elementwise op to tensorrt IScaleLayer";
113129

114130
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
115131
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); // Y is a weight
@@ -120,29 +136,35 @@ class ElementwiseTensorOpConverter : public OpConverter {
120136
nvinfer1::Dims dims_x = X->getDimensions();
121137
nvinfer1::Dims dims_y = Y->getDimensions();
122138

123-
// The two input tensor should have the same dims
124-
PADDLE_ENFORCE(dims_x.nbDims >= 3);
125-
if (dims_x.nbDims == dims_y.nbDims) {
126-
for (int i = 0; i < dims_x.nbDims; i++) {
127-
if (dims_x.d[i] != dims_y.d[i])
128-
PADDLE_THROW("TensorRT unsupported tensor shape for Elementwise op!");
129-
}
130-
} else {
131-
PADDLE_THROW("TensorRT unsupported tensor shape for Elementwise op!");
132-
}
139+
int axis = boost::get<int>(op_desc.GetAttr("axis"));
140+
auto output_name = op_desc.Output("Out")[0];
141+
if (CheckDims(dims_x, dims_y)) {
142+
// The two input tensors have the same dims
143+
VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer";
133144

134-
auto op_pair = ops.find(op_type_);
135-
if (op_pair == ops.end()) {
136-
PADDLE_THROW("Wrong elementwise op type!");
137-
}
138-
nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER(
139-
engine_, ElementWise, *const_cast<nvinfer1::ITensor*>(X),
140-
*const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
145+
nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER(
146+
engine_, ElementWise, *const_cast<nvinfer1::ITensor*>(X),
147+
*const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
141148

142-
auto output_name = op_desc.Output("Out")[0];
143-
layer->setName(("elementwise (Output: " + output_name + ")").c_str());
144-
layer->getOutput(0)->setName(output_name.c_str());
145-
engine_->SetITensor(output_name, layer->getOutput(0));
149+
layer->setName(("elementwise (Output: " + output_name + ")").c_str());
150+
layer->getOutput(0)->setName(output_name.c_str());
151+
engine_->SetITensor(output_name, layer->getOutput(0));
152+
} else {
153+
VLOG(3) << "Convert a fluid elementwise op to TensorRT "
154+
"ElementWisePluginLayer";
155+
156+
plugin::ElementWisePlugin* plugin =
157+
new plugin::ElementWisePlugin(op_pair->second, dims_x, dims_y, axis);
158+
plugin->AddInput(X);
159+
plugin->AddInput(Y);
160+
nvinfer1::IPluginLayer* layer = engine_->AddPlugin(
161+
const_cast<nvinfer1::ITensor* const*>(plugin->GetInputs().data()), 2,
162+
reinterpret_cast<plugin::PluginTensorRT*>(plugin));
163+
164+
layer->setName(("elementwise (Output: " + output_name + ")").c_str());
165+
layer->getOutput(0)->setName(output_name.c_str());
166+
engine_->SetITensor(output_name, layer->getOutput(0));
167+
}
146168
if (test_mode) { // the test framework can not determine which is the
147169
// output, so place the declaration inside.
148170
engine_->DeclareOutput(output_name);

paddle/fluid/inference/tensorrt/convert/op_converter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class OpConverter {
6161
// TODO(xingzhaolong): all mul, sub, div
6262
// static std::unordered_set<std::string> add_weight_op_set {"add", "mul",
6363
// "sub", "div"};
64-
static std::unordered_set<std::string> add_weight_op_set{"add"};
64+
static std::unordered_set<std::string> add_weight_op_set{"add", "mul"};
6565
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
6666
int op_type_len = op_desc.Type().size();
6767
std::string op_type = op_desc.Type().substr(op_type_len - 3, op_type_len);

paddle/fluid/inference/tensorrt/convert/prelu_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class PReluOpConverter : public OpConverter {
5454
TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
5555
static_cast<void*>(alpha_data),
5656
alpha_tensor_device->numel());
57-
PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
57+
plugin::PReluPlugin* plugin = new plugin::PReluPlugin(alpha_rt, mode);
5858
nvinfer1::IPluginLayer* layer =
5959
engine_->AddPlugin(&input, input_num, plugin);
6060
// keep the alpha tensor alive to avoid releasing its memory

paddle/fluid/inference/tensorrt/convert/split_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class SplitOpConverter : public OpConverter {
5050
PADDLE_ENFORCE(output_lengths.size() == output_num);
5151

5252
//
53-
SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
53+
plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths);
5454
nvinfer1::IPluginLayer* layer =
5555
engine_->AddPlugin(&input, input_num, plugin);
5656

paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ namespace paddle {
2020
namespace inference {
2121
namespace tensorrt {
2222

23-
TEST(elementwise_op, add_weight_test) {
23+
TEST(elementwise_op, add_weight) {
2424
std::unordered_set<std::string> parameters({"elementwise_add-Y"});
2525
framework::Scope scope;
2626
TRTConvertValidation validator(10, parameters, scope, 1 << 15);
2727
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
2828
validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1));
29-
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
3029
validator.DeclOutputVar("elementwise_add-Out", nvinfer1::DimsCHW(10, 3, 3));
3130

3231
// Prepare Op description
@@ -44,30 +43,65 @@ TEST(elementwise_op, add_weight_test) {
4443
validator.Execute(8);
4544
}
4645

47-
TEST(elementwise_op, add_tensor_test) {
48-
std::unordered_set<std::string> parameters;
49-
framework::Scope scope;
50-
TRTConvertValidation validator(8, parameters, scope, 1 << 15);
51-
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
52-
validator.DeclInputVar("elementwise_add-Y", nvinfer1::Dims3(10, 3, 3));
53-
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
54-
validator.DeclOutputVar("elementwise_add-Out", nvinfer1::DimsCHW(10, 3, 3));
55-
56-
// Prepare Op description
57-
framework::OpDesc desc;
58-
desc.SetType("elementwise_add");
59-
desc.SetInput("X", {"elementwise_add-X"});
60-
desc.SetInput("Y", {"elementwise_add-Y"});
61-
desc.SetOutput("Out", {"elementwise_add-Out"});
62-
63-
// the default axis of elementwise op is -1
64-
65-
validator.SetOp(*desc.Proto());
46+
TEST(elementwise_op, native) {
47+
for (std::string type : {"add", "mul"}) {
48+
int batch_size = 8;
49+
std::unordered_set<std::string> parameters;
50+
framework::Scope scope;
51+
TRTConvertValidation validator(batch_size, parameters, scope, 1 << 15);
52+
validator.DeclInputVar("elementwise_" + type + "-X",
53+
nvinfer1::DimsCHW(10, 3, 3));
54+
validator.DeclInputVar("elementwise_" + type + "-Y",
55+
nvinfer1::Dims3(10, 3, 3));
56+
validator.DeclOutputVar("elementwise_" + type + "-Out",
57+
nvinfer1::DimsCHW(10, 3, 3));
58+
59+
// Prepare Op description
60+
framework::OpDesc desc;
61+
desc.SetType("elementwise_" + type);
62+
desc.SetInput("X", {"elementwise_" + type + "-X"});
63+
desc.SetInput("Y", {"elementwise_" + type + "-Y"});
64+
desc.SetOutput("Out", {"elementwise_" + type + "-Out"});
65+
66+
int axis = -1;
67+
desc.SetAttr("axis", axis);
68+
69+
validator.SetOp(*desc.Proto());
70+
validator.Execute(batch_size);
71+
}
72+
}
6673

67-
validator.Execute(8);
74+
TEST(elementwise_op, plugin) {
75+
for (std::string type : {"add", "mul"}) {
76+
int batch_size = 8;
77+
std::unordered_set<std::string> parameters;
78+
framework::Scope scope;
79+
TRTConvertValidation validator(batch_size, parameters, scope, 1 << 15);
80+
validator.DeclInputVar("elementwise_" + type + "-X",
81+
nvinfer1::DimsCHW(10, 3, 3));
82+
validator.DeclInputVar("elementwise_" + type + "-Y",
83+
nvinfer1::Dims3(10, 1, 1));
84+
validator.DeclOutputVar("elementwise_" + type + "-Out",
85+
nvinfer1::DimsCHW(10, 3, 3));
86+
87+
// Prepare Op description
88+
framework::OpDesc desc;
89+
desc.SetType("elementwise_" + type);
90+
desc.SetInput("X", {"elementwise_" + type + "-X"});
91+
desc.SetInput("Y", {"elementwise_" + type + "-Y"});
92+
desc.SetOutput("Out", {"elementwise_" + type + "-Out"});
93+
94+
int axis = -1;
95+
desc.SetAttr("axis", axis);
96+
97+
validator.SetOp(*desc.Proto());
98+
validator.Execute(batch_size);
99+
}
68100
}
69101

70102
} // namespace tensorrt
71103
} // namespace inference
72104
} // namespace paddle
105+
73106
USE_OP(elementwise_add);
107+
USE_OP(elementwise_mul);

paddle/fluid/inference/tensorrt/convert/test_mul_op.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
22
3-
Licensed under the Apache License, Version 2.0 (the "License");
4-
you may not use this file except in compliance with the License.
5-
You may obtain a copy of the License at
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
66
7-
http://www.apache.org/licenses/LICENSE-2.0
7+
http://www.apache.org/licenses/LICENSE-2.0
88
9-
Unless required by applicable law or agreed to in writing, software
10-
distributed under the License is distributed on an "AS IS" BASIS,
11-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
See the License for the specific language governing permissions and
13-
limitations under the License. */
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
1414

1515
#include <gtest/gtest.h>
1616
#include "paddle/fluid/framework/op_registry.h"

paddle/fluid/inference/tensorrt/convert/ut_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
55
You may obtain a copy of the License at
66
7-
http://www.apache.org/licenses/LICENSE-2.0
7+
http://www.apache.org/licenses/LICENSE-2.0
88
99
Unless required by applicable law or agreed to in writing, software
1010
distributed under the License is distributed on an "AS IS" BASIS,

0 commit comments

Comments
 (0)