Skip to content

Commit 470335e

Browse files
authored
Merge pull request #12786 from NHZlX/add_batch_norm_trt_converter
Add batch norm trt converter
2 parents 3d11d01 + ff052c0 commit 470335e

File tree

4 files changed

+225
-2
lines changed

4 files changed

+225
-2
lines changed

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Add TRT tests
22
nv_library(tensorrt_converter
33
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
4-
activation_op.cc softmax_op.cc
4+
batch_norm_op.cc activation_op.cc softmax_op.cc
55
DEPS tensorrt_engine operator scope framework_proto op_registry)
66

77
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@@ -24,3 +24,6 @@ nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
2424

2525
nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
2626
DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL)
27+
28+
nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
29+
DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <math.h>
16+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
17+
18+
namespace paddle {
19+
namespace inference {
20+
namespace tensorrt {
21+
22+
class BatchNormOpConverter : public OpConverter {
23+
public:
24+
void operator()(const framework::proto::OpDesc& op,
25+
const framework::Scope& scope, bool test_mode) override {
26+
LOG(INFO) << "convert a fluid batch norm op to tensorrt batch_norm";
27+
28+
framework::OpDesc op_desc(op, nullptr);
29+
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
30+
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1); // Bias is a weight
31+
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1); // Mean is a weight
32+
PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1); // Scale is a weight
33+
PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(),
34+
1); // Variance is a weight
35+
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1);
36+
37+
auto* X = engine_->GetITensor(op_desc.Input("X").front());
38+
// Declare weights
39+
auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
40+
auto* Mean_v = scope.FindVar(op_desc.Input("Mean").front());
41+
auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
42+
auto* Variance_v = scope.FindVar(op_desc.Input("Variance").front());
43+
const float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
44+
45+
PADDLE_ENFORCE_NOT_NULL(Bias_v);
46+
PADDLE_ENFORCE_NOT_NULL(Mean_v);
47+
PADDLE_ENFORCE_NOT_NULL(Scale_v);
48+
PADDLE_ENFORCE_NOT_NULL(Variance_v);
49+
50+
// get tensor
51+
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
52+
auto* Mean_t = Mean_v->GetMutable<framework::LoDTensor>();
53+
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
54+
auto* Variance_t = Variance_v->GetMutable<framework::LoDTensor>();
55+
56+
// create temp tensor for weights
57+
framework::LoDTensor bias_tensor;
58+
framework::LoDTensor mean_tensor;
59+
framework::LoDTensor scale_tensor;
60+
framework::LoDTensor variance_tensor;
61+
62+
bias_tensor.Resize(Bias_t->dims());
63+
mean_tensor.Resize(Mean_t->dims());
64+
scale_tensor.Resize(Scale_t->dims());
65+
variance_tensor.Resize(Variance_t->dims());
66+
67+
platform::CPUPlace cpu_place;
68+
// copy data from gpu to cpu
69+
TensorCopySync((*Bias_t), cpu_place, &bias_tensor);
70+
TensorCopySync((*Mean_t), cpu_place, &mean_tensor);
71+
TensorCopySync((*Scale_t), cpu_place, &scale_tensor);
72+
TensorCopySync((*Variance_t), cpu_place, &variance_tensor);
73+
74+
auto* bias_data = bias_tensor.mutable_data<float>(platform::CPUPlace());
75+
auto* mean_data = mean_tensor.mutable_data<float>(platform::CPUPlace());
76+
auto* scale_data = scale_tensor.mutable_data<float>(platform::CPUPlace());
77+
auto* variance_data =
78+
variance_tensor.mutable_data<float>(platform::CPUPlace());
79+
80+
std::unique_ptr<framework::LoDTensor> combile_scale_tensor(
81+
new framework::LoDTensor());
82+
std::unique_ptr<framework::LoDTensor> combile_bias_tensor(
83+
new framework::LoDTensor());
84+
85+
combile_scale_tensor->Resize(scale_tensor.dims());
86+
combile_bias_tensor->Resize(bias_tensor.dims());
87+
88+
auto* combile_scale_data =
89+
combile_scale_tensor->mutable_data<float>(platform::CPUPlace());
90+
auto* combile_bias_data =
91+
combile_bias_tensor->mutable_data<float>(platform::CPUPlace());
92+
93+
size_t ele_num = combile_scale_tensor->memory_size() / sizeof(float);
94+
95+
for (size_t i = 0; i < ele_num; i++) {
96+
float scale = scale_data[i];
97+
float bias = bias_data[i];
98+
float mean = mean_data[i];
99+
float variance = variance_data[i];
100+
combile_scale_data[i] = scale / sqrtf(variance + eps);
101+
combile_bias_data[i] = bias - mean * combile_scale_data[i];
102+
}
103+
104+
TensorRTEngine::Weight scale_weights{
105+
nvinfer1::DataType::kFLOAT, static_cast<void*>(combile_scale_data),
106+
combile_scale_tensor->memory_size() / sizeof(float)};
107+
TensorRTEngine::Weight shift_weights{
108+
nvinfer1::DataType::kFLOAT, static_cast<void*>(combile_bias_data),
109+
combile_bias_tensor->memory_size() / sizeof(float)};
110+
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
111+
0};
112+
113+
nvinfer1::IScaleLayer* layer =
114+
TRT_ENGINE_ADD_LAYER(engine_, Scale, *const_cast<nvinfer1::ITensor*>(X),
115+
nvinfer1::ScaleMode::kCHANNEL, shift_weights.get(),
116+
scale_weights.get(), power_weights.get());
117+
118+
auto output_name = op_desc.Output("Y").front();
119+
engine_->weight_map[op_desc.Input("Bias").front()] =
120+
std::move(combile_bias_tensor);
121+
engine_->weight_map[op_desc.Input("Scale").front()] =
122+
std::move(combile_scale_tensor);
123+
124+
engine_->SetITensor(output_name, layer->getOutput(0));
125+
126+
if (test_mode) {
127+
engine_->DeclareOutput(output_name);
128+
}
129+
}
130+
};
131+
132+
} // namespace tensorrt
133+
} // namespace inference
134+
} // namespace paddle
135+
136+
REGISTER_TRT_OP_CONVERTER(batch_norm, BatchNormOpConverter);
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
17+
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
18+
19+
namespace paddle {
20+
namespace inference {
21+
namespace tensorrt {
22+
23+
// Validates the batch_norm converter: run the fluid op and the generated
// TensorRT layer on the same random input and compare the "Y" output.
TEST(batch_norm_op, test) {
  // Scale/Bias/Mean/Variance are weights (parameters), not runtime inputs.
  std::unordered_set<std::string> parameters(
      {"batch_norm_scale", "batch_norm_bias", "batch_norm_mean",
       "batch_norm_variance"});
  framework::Scope scope;
  TRTConvertValidation validator(5, parameters, scope, 1 << 15);

  // One weight element per channel (the input has 2 channels).
  const std::vector<int> param_shape{2};

  validator.DeclInputVar("batch_norm_X", nvinfer1::DimsCHW(2, 5, 5));
  validator.DeclParamVar("batch_norm_scale", param_shape);
  validator.DeclParamVar("batch_norm_bias", param_shape);
  validator.DeclParamVar("batch_norm_mean", param_shape);
  validator.DeclParamVar("batch_norm_variance", param_shape);
  validator.DeclOutputVar("batch_norm_Y", nvinfer1::DimsCHW(2, 5, 5));
  validator.DeclOutputVar("batch_norm_save_mean", param_shape);
  validator.DeclOutputVar("batch_norm_save_variance", param_shape);

  // Prepare the op description.
  framework::OpDesc desc;
  desc.SetType("batch_norm");
  desc.SetInput("X", {"batch_norm_X"});
  desc.SetInput("Scale", {"batch_norm_scale"});
  desc.SetInput("Bias", {"batch_norm_bias"});
  desc.SetInput("Mean", {"batch_norm_mean"});
  desc.SetInput("Variance", {"batch_norm_variance"});
  desc.SetOutput("Y", {"batch_norm_Y"});
  desc.SetOutput("MeanOut", {"batch_norm_mean"});
  desc.SetOutput("VarianceOut", {"batch_norm_variance"});
  desc.SetOutput("SavedMean", {"batch_norm_save_mean"});
  desc.SetOutput("SavedVariance", {"batch_norm_save_variance"});
  desc.SetAttr("epsilon", 1e-5f);
  desc.SetAttr("is_test", true);

  validator.SetOp(*desc.Proto());

  // These outputs only matter during training; skip them when comparing.
  std::unordered_set<std::string> skipped_outputs = {
      "batch_norm_save_mean", "batch_norm_save_variance", "batch_norm_mean",
      "batch_norm_variance"};
  validator.Execute(3, skipped_outputs);
}
67+
68+
} // namespace tensorrt
69+
} // namespace inference
70+
} // namespace paddle
71+
USE_OP(batch_norm);

paddle/fluid/inference/tensorrt/convert/ut_helper.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,19 @@ class TRTConvertValidation {
9898
engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
9999
}
100100

101+
// Declare a parameter variable (weight) in the scope, described by a shape
// vector. Takes the vector by const reference to avoid an unnecessary copy.
void DeclParamVar(const std::string& name, const std::vector<int>& dim_vec) {
  DeclVar(name, dim_vec);
}
104+
101105
// Declare a parameter variable in the scope.
102106
void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) {
103107
DeclVar(name, dims, true);
104108
}
105109

110+
// Declare an output variable in the scope, described by a shape vector.
// Takes the vector by const reference to avoid an unnecessary copy.
void DeclOutputVar(const std::string& name, const std::vector<int>& dim_vec) {
  DeclVar(name, dim_vec);
}
113+
106114
void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
107115
DeclVar(name, dims);
108116
}
@@ -155,7 +163,11 @@ class TRTConvertValidation {
155163
}
156164
}
157165

158-
void Execute(int batch_size) {
166+
// We use the set 'neglected_output' here because, for some ops like
// batch_norm, the outputs specified in the op desc are only used during
// training, so we should neglect those outputs during inference.
169+
void Execute(int batch_size,
170+
std::unordered_set<std::string> neglected_output = {}) {
159171
// Execute Fluid Op
160172
PADDLE_ENFORCE_LE(batch_size, max_batch_size_);
161173
platform::CUDAPlace place;
@@ -168,6 +180,7 @@ class TRTConvertValidation {
168180
ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
169181
const size_t output_space_size = 3000;
170182
for (const auto& output : op_desc_->OutputArgumentNames()) {
183+
if (neglected_output.count(output)) continue;
171184
std::vector<float> fluid_out;
172185
std::vector<float> trt_out(output_space_size);
173186
engine_->GetOutputInCPU(output, &trt_out[0], output_space_size);

0 commit comments

Comments
 (0)