Skip to content

Commit 77ac30e

Browse files
authored
Merge pull request #14386 from NHZlX/add_trt_plugin
Add TensorRT plugin support for Paddle-TRT inference
2 parents 8cfda7e + 15bdb7e commit 77ac30e

File tree

15 files changed

+557
-5
lines changed

15 files changed

+557
-5
lines changed

paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
4545
std::unordered_set<std::string> teller_set(
4646
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
4747
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
48-
"elementwise_add", "dropout"});
48+
"elementwise_add", "dropout", "split"});
4949
if (!node->IsOp()) return false;
5050

5151
if (teller_set.count(node->Op()->Type())) {

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,4 +548,5 @@ USE_TRT_CONVERTER(batch_norm);
548548
USE_TRT_CONVERTER(concat);
549549
USE_TRT_CONVERTER(dropout);
550550
USE_TRT_CONVERTER(pad);
551+
USE_TRT_CONVERTER(split);
551552
#endif
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
22
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
33
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
4+
add_subdirectory(plugin)
45
add_subdirectory(convert)

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Add TRT tests
22
nv_library(tensorrt_converter
33
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
4-
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
5-
DEPS tensorrt_engine operator scope framework_proto op_registry)
4+
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
5+
pad_op.cc split_op.cc
6+
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
67

78
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
89
${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
2829
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
2930
nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
3031
DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
31-
3232
nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
3333
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
34+
nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
35+
DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
36+
split_op concat_op SERIAL)

paddle/fluid/inference/tensorrt/convert/concat_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace inference {
1919
namespace tensorrt {
2020

2121
/*
22-
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
22+
* ConcatOp
2323
*/
2424
class ConcatOpConverter : public OpConverter {
2525
public:
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
16+
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
17+
18+
namespace paddle {
19+
namespace inference {
20+
namespace tensorrt {
21+
22+
/*
23+
* SplitOp.
24+
*/
25+
class SplitOpConverter : public OpConverter {
26+
public:
27+
void operator()(const framework::proto::OpDesc& op,
28+
const framework::Scope& scope, bool test_mode) override {
29+
VLOG(40) << "convert a fluid split op to tensorrt split layer";
30+
31+
framework::OpDesc op_desc(op, nullptr);
32+
// Declare inputs
33+
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
34+
auto input_dims = input->getDimensions();
35+
int input_num = op_desc.Input("X").size();
36+
size_t output_num = op_desc.Output("Out").size();
37+
38+
// Get Attrs
39+
PADDLE_ENFORCE(input_num == 1);
40+
int axis = boost::get<int>(op_desc.GetAttr("axis"));
41+
std::vector<int> output_lengths =
42+
boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
43+
PADDLE_ENFORCE(axis != 0);
44+
if (axis < 0) {
45+
axis += input_dims.nbDims;
46+
} else {
47+
axis -= 1;
48+
}
49+
50+
PADDLE_ENFORCE(output_lengths.size() == output_num);
51+
52+
//
53+
SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
54+
nvinfer1::IPluginLayer* layer =
55+
engine_->AddPlugin(&input, input_num, plugin);
56+
57+
std::string layer_name = "split (Output: ";
58+
for (size_t i = 0; i < output_num; i++) {
59+
auto output_name = op_desc.Output("Out")[i];
60+
layer->getOutput(i)->setName(output_name.c_str());
61+
engine_->SetITensor(output_name, layer->getOutput(i));
62+
layer_name += output_name;
63+
if (test_mode) {
64+
engine_->DeclareOutput(output_name);
65+
}
66+
}
67+
layer->setName((layer_name + ")").c_str());
68+
}
69+
};
70+
71+
} // namespace tensorrt
72+
} // namespace inference
73+
} // namespace paddle
74+
75+
REGISTER_TRT_OP_CONVERTER(split, SplitOpConverter);
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
17+
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
18+
19+
namespace paddle {
20+
namespace inference {
21+
namespace tensorrt {
22+
23+
TEST(split_op, test) {
24+
std::unordered_set<std::string> parameters({""});
25+
framework::Scope scope;
26+
TRTConvertValidation validator(10, parameters, scope, 1000);
27+
validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
28+
validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
29+
validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));
30+
31+
// Prepare Op description
32+
framework::OpDesc desc;
33+
desc.SetType("split");
34+
desc.SetInput("X", {"split_input"});
35+
desc.SetOutput("Out", {"split_out1", "split_out2"});
36+
37+
int num = 0;
38+
int axis = 1;
39+
std::vector<int> output_lengths = {2, 1};
40+
desc.SetAttr("axis", axis);
41+
desc.SetAttr("num", num);
42+
desc.SetAttr("sections", output_lengths);
43+
44+
validator.SetOp(*desc.Proto());
45+
46+
validator.Execute(1);
47+
}
48+
49+
} // namespace tensorrt
50+
} // namespace inference
51+
} // namespace paddle
52+
53+
USE_OP(split);

paddle/fluid/inference/tensorrt/engine.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,12 @@ void TensorRTEngine::freshDeviceId() {
255255
cudaSetDevice(device_);
256256
}
257257

258+
nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
259+
nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
260+
owned_plugin_.emplace_back(plugin);
261+
return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
262+
}
263+
258264
} // namespace tensorrt
259265
} // namespace inference
260266
} // namespace paddle

paddle/fluid/inference/tensorrt/engine.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License. */
2222
#include "paddle/fluid/framework/tensor.h"
2323
#include "paddle/fluid/inference/engine.h"
2424
#include "paddle/fluid/inference/tensorrt/helper.h"
25+
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
2526
#include "paddle/fluid/inference/utils/singleton.h"
2627

2728
namespace paddle {
@@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase {
125126
void SetRuntimeBatch(size_t batch_size);
126127
int GetRuntimeBatch();
127128
int GetDevice() { return device_; }
129+
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
130+
int nbInputs, PluginTensorRT*);
128131

129132
// A pointer to CPU memory is needed of the TRT weight.
130133
// Before TRT runs, fluid loads weight into GPU storage.
@@ -164,8 +167,10 @@ class TensorRTEngine : public EngineBase {
164167
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
165168
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
166169
itensor_map_;
170+
167171
// The specific GPU id that the TensorRTEngine bounded to.
168172
int device_;
173+
std::vector<std::unique_ptr<PluginTensorRT>> owned_plugin_;
169174

170175
// TensorRT related internal members
171176
template <typename T>
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)

0 commit comments

Comments
 (0)