Commit fe00d32

Author: Pei Yang

[Paddle-TRT] support group_norm (#31040) (#31188)
1 parent 011a6a5 commit fe00d32

File tree

6 files changed: +221 additions, -7 deletions

6 files changed

+221
-7
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 1 addition & 0 deletions
@@ -1173,6 +1173,7 @@ USE_TRT_CONVERTER(conv2d_transpose);
 USE_TRT_CONVERTER(leaky_relu);
 USE_TRT_CONVERTER(shuffle_channel);
 USE_TRT_CONVERTER(swish);
+USE_TRT_CONVERTER(group_norm);
 USE_TRT_CONVERTER(instance_norm);
 USE_TRT_CONVERTER(layer_norm);
 USE_TRT_CONVERTER(gelu);

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 # Add TRT tests
 nv_library(tensorrt_converter
   SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
+  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc
   pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
   shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
   emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc

paddle/fluid/inference/tensorrt/convert/concat_op.cc

Lines changed: 1 addition & 6 deletions
@@ -34,7 +34,7 @@ class ConcatOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
-    VLOG(3) << "convert a fluid mul op to tensorrt mul layer without bias";
+    VLOG(3) << "convert a paddle concat op to tensorrt concat layer";
 
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs
@@ -43,11 +43,6 @@ class ConcatOpConverter : public OpConverter {
       itensors.push_back(engine_->GetITensor(input_name));
     }
     int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
-    PADDLE_ENFORCE_GT(axis, 0, platform::errors::InvalidArgument(
-                                   "The axis attr of Concat"
-                                   " op should be larger than 0 for trt. "
-                                   "But received %d.",
-                                   axis));
 
     auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(),
                                        itensors.size());
paddle/fluid/inference/tensorrt/convert/group_norm_op.cc

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class GroupNormOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(3) << "convert a fluid group_norm op";
+
+    framework::OpDesc op_desc(op, nullptr);
+
+    auto* input_itensor = engine_->GetITensor(op_desc.Input("X").front());
+
+    int groups = BOOST_GET_CONST(int, op_desc.GetAttr("groups"));
+    float epsilon = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
+
+    std::string scale_name = op_desc.Input("Scale").front();
+    std::string bias_name = op_desc.Input("Bias").front();
+
+    // get the persistable var's data
+    auto get_persistable_data = [&](const std::string& var_name,
+                                    framework::DDim* dims) -> float* {
+      auto* temp_var = scope.FindVar(var_name);
+      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
+      (*dims) = temp_tensor->dims();
+
+      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
+      return temp_data;
+    };
+
+    framework::DDim scale_dims;
+    framework::DDim bias_dims;
+    float* scale_data = get_persistable_data(scale_name, &scale_dims);
+    float* bias_data = get_persistable_data(bias_name, &bias_dims);
+
+    int64_t scale_numel = framework::product(scale_dims);
+    int64_t bias_numel = framework::product(bias_dims);
+
+    TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
+                                         static_cast<void*>(scale_data),
+                                         static_cast<size_t>(scale_numel)};
+    TensorRTEngine::Weight bias_weights{nvinfer1::DataType::kFLOAT,
+                                        static_cast<void*>(bias_data),
+                                        static_cast<size_t>(bias_numel)};
+
+    nvinfer1::Dims scale_nv_dims;
+    nvinfer1::Dims bias_nv_dims;
+    scale_nv_dims.nbDims = scale_dims.size();
+    bias_nv_dims.nbDims = bias_dims.size();
+    for (int i = 0; i < scale_dims.size(); i++) {
+      scale_nv_dims.d[i] = scale_dims.at(i);
+    }
+    for (int i = 0; i < bias_dims.size(); i++) {
+      bias_nv_dims.d[i] = bias_dims.at(i);
+    }
+
+    auto* scale_layer = TRT_ENGINE_ADD_LAYER(engine_, Constant, scale_nv_dims,
+                                             scale_weights.get());
+    auto* bias_layer = TRT_ENGINE_ADD_LAYER(engine_, Constant, bias_nv_dims,
+                                            bias_weights.get());
+
+    std::vector<nvinfer1::ITensor*> plugin_inputs;
+    plugin_inputs.emplace_back(input_itensor);
+    plugin_inputs.emplace_back(scale_layer->getOutput(0));
+    plugin_inputs.emplace_back(bias_layer->getOutput(0));
+
+    const std::vector<nvinfer1::PluginField> fields{
+        {"eps", &epsilon, nvinfer1::PluginFieldType::kFLOAT32, 1},
+        {"num_groups", &groups, nvinfer1::PluginFieldType::kINT32, 1},
+    };
+
+    nvinfer1::PluginFieldCollection* plugin_collections =
+        static_cast<nvinfer1::PluginFieldCollection*>(
+            malloc(sizeof(*plugin_collections) +
+                   fields.size() * sizeof(nvinfer1::PluginField)));
+    plugin_collections->nbFields = static_cast<int>(fields.size());
+    plugin_collections->fields = fields.data();
+
+    auto creator =
+        GetPluginRegistry()->getPluginCreator("GroupNormalizationPlugin", "1");
+    auto group_norm_plugin =
+        creator->createPlugin("GroupNormalizationPlugin", plugin_collections);
+    free(plugin_collections);
+
+    auto group_norm_plugin_layer = engine_->network()->addPluginV2(
+        plugin_inputs.data(), plugin_inputs.size(), *group_norm_plugin);
+
+    auto output_name = op_desc.Output("Y")[0];
+    RreplenishLayerAndOutput(group_norm_plugin_layer, "group_norm",
+                             {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(group_norm, GroupNormOpConverter);
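For reference, the converter does not implement the normalization math itself; it only wires the input, Scale, and Bias tensors into TensorRT's GroupNormalizationPlugin together with the "eps" and "num_groups" fields. Below is a minimal NumPy sketch of the computation that plugin is expected to perform on an NCHW input; the function and array names are illustrative and not part of the commit.

import numpy as np

def group_norm_reference(x, scale, bias, num_groups, eps):
    # x: NCHW activations; scale/bias: per-channel persistable weights.
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    # Normalize each group of channels with its own mean and variance.
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    y = ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    # Apply the per-channel affine transform carried by Scale and Bias.
    return y * scale.reshape(1, c, 1, 1) + bias.reshape(1, c, 1, 1)

As the op_teller change below reflects, the op is only offered to TensorRT when the build is at least 7.1.3 and the plugin registry actually exposes this plugin.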

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 18 additions & 0 deletions
@@ -42,6 +42,9 @@ struct SimpleOpTypeSetTeller : public Teller {
     teller_set.insert("multihead_matmul");
     teller_set.insert("skip_layernorm");
     teller_set.insert("slice");
+#endif
+#if IS_TRT_VERSION_GE(7130)
+    teller_set.insert("group_norm");
 #endif
   }
 
@@ -150,6 +153,21 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       }
     }
   }
+  if (op_type == "group_norm") {
+    bool has_attrs = (desc.HasAttr("epsilon") && desc.HasAttr("groups"));
+    if (has_attrs == false) return false;
+
+    auto registry = GetPluginRegistry();
+    if (registry == nullptr) return false;
+  }
+  if (op_type == "concat") {
+    if (!desc.HasAttr("axis")) {
+      return false;
+    } else {
+      int axis = BOOST_GET_CONST(int, desc.GetAttr("axis"));
+      if (axis <= 0) return false;
+    }
+  }
   if (op_type == "transpose2" || op_type == "transpose") {
     if (!desc.HasAttr("axis")) {
       return false;
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from inference_pass_test import InferencePassTest
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.core import PassVersionChecker
+from paddle.fluid.core import AnalysisConfig
+
+
+class TRTGroupNormTest(InferencePassTest):
+    def setUp(self):
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 512, 12, 12], dtype="float32")
+            relu_out = fluid.layers.relu(data)
+            relu6_out = fluid.layers.relu6(relu_out)
+            tanh_out = fluid.layers.tanh(relu6_out)
+            conv_out = fluid.layers.conv2d(
+                input=tanh_out,
+                num_filters=512,
+                filter_size=3,
+                groups=1,
+                padding=[1, 1],
+                bias_attr=False,
+                act=None)
+            out = self.append_group_norm(conv_out)
+
+        self.feeds = {
+            "data": np.random.random([1, 512, 12, 12]).astype("float32"),
+        }
+        self.enable_trt = True
+        self.trt_parameters = TRTGroupNormTest.TensorRTParam(
+            1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
+        self.dynamic_shape_params = TRTGroupNormTest.DynamicShapeParam({
+            'data': [1, 512, 12, 12]
+        }, {'data': [1, 512, 12, 12]}, {'data': [1, 512, 12, 12]}, False)
+        self.fetch_list = [out]
+
+    def append_group_norm(self, data):
+        param_attr = fluid.ParamAttr(
+            name='group_norm_scale',
+            initializer=fluid.initializer.Constant(value=1.0))
+        bias_attr = fluid.ParamAttr(
+            name='group_norm_bias',
+            initializer=fluid.initializer.Constant(value=0.0))
+        return fluid.layers.group_norm(
+            data,
+            groups=32,
+            epsilon=0.000009999999747378752,
+            param_attr=param_attr,
+            bias_attr=bias_attr)
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            use_gpu = True
+            self.check_output_with_option(use_gpu)
+            self.assertTrue(
+                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
+
+
+if __name__ == "__main__":
+    unittest.main()
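The test exercises the new converter through InferencePassTest, which turns the TensorRTParam tuple above into an AnalysisConfig. As a rough, hedged sketch of the same settings outside the test harness: the model directory below is a placeholder, the positional arguments mirror the tuple (workspace_size, max_batch_size, min_subgraph_size, precision, use_static, use_calib_mode), and the exact binding names are an assumption rather than part of this commit.

from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

# Placeholder directory holding a saved inference model that contains group_norm.
config = AnalysisConfig("./group_norm_model")
config.enable_use_gpu(100, 0)  # GPU memory pool in MB, device id
# Mirrors TensorRTParam(1 << 30, 32, 1, Precision.Float32, False, False) above.
config.enable_tensorrt_engine(1 << 30, 32, 1,
                              AnalysisConfig.Precision.Float32, False, False)
predictor = create_paddle_predictor(config)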
