Commit 552cdc1

Merge pull request #13422 from NHZlX/add_dropout_simoid_trt
Add dropout sigmoid op converter for trt
2 parents 4c48918 + cc4a766 commit 552cdc1

7 files changed: 202 additions & 16 deletions

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 3 additions & 2 deletions
@@ -69,8 +69,9 @@ class DfgPassManagerImpl final : public DfgPassManager {
     if (FLAGS_IA_enable_tensorrt_subgraph_engine) {
       auto trt_teller = [&](const Node* node) {
         std::unordered_set<std::string> teller_set(
-            {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax",
-             "depthwise_conv2d", "batch_norm", "concat"});
+            {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
+             "depthwise_conv2d", "batch_norm", "concat", "tanh",
+             "elementwise_add", "dropout"});
         if (!node->IsFunction()) return false;

         const auto* func = static_cast<const Function*>(node);
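
Note: the teller is the predicate that decides which graph nodes get offloaded to the TensorRT subgraph; a node qualifies only if it is a function node and its op type appears in the set above. A minimal standalone sketch of the same check, with a hypothetical FakeNode standing in for the analysis Node API:

#include <string>
#include <unordered_set>

// FakeNode is a stand-in for illustration only; the real Node exposes
// IsFunction() and the wrapped function's op type instead.
struct FakeNode {
  bool is_function;
  std::string op_type;
};

bool TrtSupports(const FakeNode& node) {
  static const std::unordered_set<std::string> teller_set(
      {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
       "depthwise_conv2d", "batch_norm", "concat", "tanh",
       "elementwise_add", "dropout"});
  if (!node.is_function) return false;  // only function nodes can be offloaded
  return teller_set.count(node.op_type) > 0;
}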

paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc

Lines changed: 10 additions & 0 deletions
@@ -153,11 +153,21 @@ CreatePaddlePredictor<TensorRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
 }  // namespace paddle

 USE_TRT_CONVERTER(elementwise_add_weight);
+USE_TRT_CONVERTER(elementwise_add_tensor);
+USE_TRT_CONVERTER(elementwise_sub_tensor);
+USE_TRT_CONVERTER(elementwise_div_tensor);
+USE_TRT_CONVERTER(elementwise_mul_tensor);
+USE_TRT_CONVERTER(elementwise_max_tensor);
+USE_TRT_CONVERTER(elementwise_min_tensor);
+USE_TRT_CONVERTER(elementwise_pow_tensor);
 USE_TRT_CONVERTER(mul);
 USE_TRT_CONVERTER(conv2d);
 USE_TRT_CONVERTER(relu);
+USE_TRT_CONVERTER(sigmoid);
+USE_TRT_CONVERTER(tanh);
 USE_TRT_CONVERTER(fc);
 USE_TRT_CONVERTER(pool2d);
 USE_TRT_CONVERTER(softmax);
 USE_TRT_CONVERTER(batch_norm);
 USE_TRT_CONVERTER(concat);
+USE_TRT_CONVERTER(dropout);
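
Note: USE_TRT_CONVERTER pairs with REGISTER_TRT_OP_CONVERTER, which registers each converter through a static object in its own translation unit; referencing a symbol from that unit here keeps the linker from discarding the registration. A generic sketch of this touch-symbol idiom (Paddle's actual macro bodies differ in detail):

#include <functional>
#include <map>
#include <string>

using ConverterFn = std::function<void()>;

std::map<std::string, ConverterFn>& Registry() {
  static std::map<std::string, ConverterFn> registry;
  return registry;
}

// Expanded in the converter's .cc file: defines a Touch function whose
// first call performs the registration.
#define REGISTER_CONVERTER(op, fn)                            \
  int TouchConverter_##op() {                                 \
    static bool registered = (Registry()[#op] = (fn), true);  \
    return registered ? 0 : -1;                               \
  }

// Expanded where the predictor is linked: the extern reference forces
// the converter's object file (and thus its registration) to be kept.
#define USE_CONVERTER(op)                                     \
  extern int TouchConverter_##op();                           \
  static int use_converter_##op = TouchConverter_##op();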

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,7 @@
 # Add TRT tests
 nv_library(tensorrt_converter
   SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc
+  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
   DEPS tensorrt_engine operator scope framework_proto op_registry)

 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@@ -24,6 +24,8 @@ nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL)
 nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
-
 nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
+
+nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
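
Note: nv_library and nv_test only take effect in a CUDA-enabled build, so the new converter and its unit test compile only when TensorRT support is on. Once built, the test should be runnable in isolation with "ctest -R test_trt_dropout_op"; the SERIAL keyword keeps ctest from scheduling it alongside other GPU tests.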

paddle/fluid/inference/tensorrt/convert/activation_op.cc

Lines changed: 42 additions & 6 deletions
@@ -19,34 +19,70 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {

-class ReluOpConverter : public OpConverter {
+class ActivationOpConverter : public OpConverter {
  public:
-  ReluOpConverter() {}
+  ActivationOpConverter() {}
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
     // Here the two nullptr looks strange, that's because the
     // framework::OpDesc's constructor is strange.
     framework::OpDesc op_desc(op, nullptr);
-    LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
-                 "type is Relu";
+    LOG(INFO)
+        << "convert a fluid Activation op to tensorrt activation layer whose "
+           "type is "
+        << op_type_;
     const nvinfer1::ITensor* input_tensor =
         engine_->GetITensor(op_desc.Input("X")[0]);
+
+    auto op_pair = ops.find(op_type_);
+    if (op_pair == ops.end()) {
+      PADDLE_THROW("Wrong activation op type!");
+    }
+
     nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
         engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
-        nvinfer1::ActivationType::kRELU);
+        op_pair->second);
     auto output_name = op_desc.Output("Out")[0];
-    layer->setName(("relu (Output: " + output_name + ")").c_str());
+    layer->setName((op_type_ + " (Output: " + output_name + ")").c_str());
     layer->getOutput(0)->setName(output_name.c_str());
     engine_->SetITensor(output_name, layer->getOutput(0));
     if (test_mode) {  // the test framework can not determine which is the
                       // output, so place the declaration inside.
       engine_->DeclareOutput(output_name);
     }
   }
+
+ protected:
+  std::string op_type_;
+  static const std::unordered_map<std::string, nvinfer1::ActivationType> ops;
+};
+
+const std::unordered_map<std::string, nvinfer1::ActivationType>
+    ActivationOpConverter::ops = {
+        {"relu", nvinfer1::ActivationType::kRELU},
+        {"sigmoid", nvinfer1::ActivationType::kSIGMOID},
+        {"tanh", nvinfer1::ActivationType::kTANH},
+};
+
+class ReluOpConverter : public ActivationOpConverter {
+ public:
+  ReluOpConverter() { op_type_ = "relu"; }
+};
+
+class SigmoidOpConverter : public ActivationOpConverter {
+ public:
+  SigmoidOpConverter() { op_type_ = "sigmoid"; }
+};
+
+class TanhOpConverter : public ActivationOpConverter {
+ public:
+  TanhOpConverter() { op_type_ = "tanh"; }
 };

 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle

 REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
+REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter);
+REGISTER_TRT_OP_CONVERTER(tanh, TanhOpConverter);
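
Note: the refactor folds three near-identical converters into one table-driven base class; each subclass only pins op_type_, and the shared operator() looks the TensorRT activation type up in a static map. A standalone sketch of that dispatch pattern, stripped of the TensorRT/Paddle types:

#include <stdexcept>
#include <string>
#include <unordered_map>

enum class ActType { kRELU, kSIGMOID, kTANH };

class ActivationBase {
 public:
  // Same lookup the converter performs before adding the TRT layer.
  ActType Type() const {
    static const std::unordered_map<std::string, ActType> ops = {
        {"relu", ActType::kRELU},
        {"sigmoid", ActType::kSIGMOID},
        {"tanh", ActType::kTANH},
    };
    auto it = ops.find(op_type_);
    if (it == ops.end()) throw std::runtime_error("Wrong activation op type!");
    return it->second;
  }

 protected:
  std::string op_type_;
};

// Supporting a new activation is one map entry plus a two-line subclass.
class SigmoidLike : public ActivationBase {
 public:
  SigmoidLike() { op_type_ = "sigmoid"; }
};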
paddle/fluid/inference/tensorrt/convert/dropout_op.cc

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * DropoutOp. This layer doesn't have weights.
+ */
+class DropoutOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(4) << "convert a fluid dropout op to tensorrt dropout layer";
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
+    float dropout_prob = boost::get<float>(op_desc.GetAttr("dropout_prob"));
+
+    platform::CPUPlace cpu_place;
+    std::unique_ptr<framework::LoDTensor> weight_tensor(
+        new framework::LoDTensor());
+    weight_tensor->Resize(framework::make_ddim({1}));
+    auto* weight_data =
+        weight_tensor->mutable_data<float>(platform::CPUPlace());
+    weight_data[0] = 1 - dropout_prob;
+
+    TensorRTEngine::Weight scale_weights{
+        nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+        weight_tensor->memory_size() / sizeof(float)};
+    TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, nullptr,
+                                         0};
+    TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
+                                         0};
+
+    auto* layer = TRT_ENGINE_ADD_LAYER(
+        engine_, Scale, *const_cast<nvinfer1::ITensor*>(input1),
+        nvinfer1::ScaleMode::kUNIFORM, shift_weights.get(), scale_weights.get(),
+        power_weights.get());
+
+    engine_->weight_map[op_desc.Output("Out").front() + "_dropout"] =
+        std::move(weight_tensor);
+    auto output_name = op_desc.Output("Out")[0];
+    layer->setName(("dropout (Output: " + output_name + ")").c_str());
+    engine_->SetITensor(output_name, layer->getOutput(0));
+    if (test_mode) {
+      engine_->DeclareOutput(output_name);
+    }
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(dropout);
+REGISTER_TRT_OP_CONVERTER(dropout, DropoutOpConverter);
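
Note: TensorRT has no dropout layer, and none is needed at inference: with is_test set, fluid's dropout reduces to scaling the input by (1 - dropout_prob), which is exactly what the converter builds with a Scale layer. In kUNIFORM mode IScaleLayer computes out = (in * scale + shift) ^ power elementwise, and TensorRT treats the empty shift and power weights as 0 and 1 respectively. A quick sanity check of that arithmetic:

#include <cstdio>

int main() {
  const float dropout_prob = 0.4f;          // attribute the converter reads
  const float scale = 1.0f - dropout_prob;  // the single weight it allocates
  const float shift = 0.0f, power = 1.0f;   // defaults for empty TRT weights
  const float xs[3] = {1.0f, -2.0f, 0.5f};
  for (float x : xs) {
    // kUNIFORM formula; with power == 1 this is just x * (1 - dropout_prob).
    float y = x * scale + shift;
    std::printf("%6.2f -> %6.2f\n", x, y);
  }
  return 0;
}

The converter parks the one-element LoDTensor in engine_->weight_map so the buffer outlives network construction, since TensorRT keeps only a raw pointer to the weight data.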

paddle/fluid/inference/tensorrt/convert/test_activation_op.cc

Lines changed: 14 additions & 6 deletions
@@ -20,18 +20,18 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {

-TEST(ReluOpConverter, main) {
+void test_activation(std::string act_type) {
   framework::Scope scope;
   std::unordered_set<std::string> parameters;
   TRTConvertValidation validator(10, parameters, scope, 1000);
-  validator.DeclInputVar("relu-X", nvinfer1::Dims2(10, 6));
-  validator.DeclOutputVar("relu-Out", nvinfer1::Dims2(10, 6));
+  validator.DeclInputVar("act-X", nvinfer1::Dims2(10, 6));
+  validator.DeclOutputVar("act-Out", nvinfer1::Dims2(10, 6));

   // Prepare Op description
   framework::OpDesc desc;
-  desc.SetType("relu");
-  desc.SetInput("X", {"relu-X"});
-  desc.SetOutput("Out", {"relu-Out"});
+  desc.SetType(act_type);
+  desc.SetInput("X", {"act-X"});
+  desc.SetOutput("Out", {"act-Out"});

   LOG(INFO) << "set OP";
   validator.SetOp(*desc.Proto());
@@ -40,8 +40,16 @@ TEST(ReluOpConverter, main) {
   validator.Execute(5);
 }

+TEST(ReluOpConverter, main) { test_activation("relu"); }
+
+TEST(SigmoidOpConverter, main) { test_activation("sigmoid"); }
+
+TEST(TanhOpConverter, main) { test_activation("tanh"); }
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle

 USE_OP(relu);
+USE_OP(sigmoid);
+USE_OP(tanh);
paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(DropoutOpConverter, main) {
+  framework::Scope scope;
+  std::unordered_set<std::string> parameters;
+  TRTConvertValidation validator(8, parameters, scope, 1000);
+
+  std::vector<int> tensor_shape{8, 10};
+  validator.DeclInputVar("dropout-X", tensor_shape,
+                         nvinfer1::DimsCHW(10, 1, 1));
+  validator.DeclOutputVar("dropout-Out", nvinfer1::DimsCHW(10, 1, 1));
+  validator.DeclOutputVar("mask-Out", nvinfer1::DimsCHW(10, 1, 1));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  int is_test = 1;
+  float dropout_prob = 0.4;
+
+  desc.SetType("dropout");
+  desc.SetInput("X", {"dropout-X"});
+  desc.SetOutput("Mask", {"mask-Out"});
+  desc.SetOutput("Out", {"dropout-Out"});
+  desc.SetAttr("is_test", is_test);
+  desc.SetAttr("dropout_prob", dropout_prob);
+
+  LOG(INFO) << "set OP";
+  validator.SetOp(*desc.Proto());
+  LOG(INFO) << "execute";
+
+  std::unordered_set<std::string> neglected_output = {"mask-Out"};
+
+  validator.Execute(8, neglected_output);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(dropout);
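
Note: fluid's dropout op also emits a Mask output (used by the backward pass), but the TensorRT Scale layer produces no counterpart, so the test registers mask-Out as a neglected output and the validator presumably skips it when comparing results against the fluid op.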
