Skip to content

Commit e5bd7eb

Browse files
authored
Add trt layer norm dynamic (#33448)
* 1, remove layernorm dynamic fp16; 2, let reshape out in dynamic shape (#33535)
1 parent c334d2b commit e5bd7eb

File tree

7 files changed

+336
-23
lines changed

7 files changed

+336
-23
lines changed

paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,6 @@ class LayerNormOpConverter : public OpConverter {
4646
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
4747
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
4848

49-
int input_num = 1;
50-
for (int i = 0; i < X->getDimensions().nbDims; i++) {
51-
input_num *= X->getDimensions().d[i];
52-
}
53-
std::vector<int64_t> mean_shape{input_num};
54-
std::vector<int64_t> variance_shape{input_num};
55-
5649
std::unique_ptr<framework::LoDTensor> bias_tensor(
5750
new framework::LoDTensor());
5851
std::unique_ptr<framework::LoDTensor> scale_tensor(
@@ -68,10 +61,33 @@ class LayerNormOpConverter : public OpConverter {
6861
auto* bias_data = bias_tensor->mutable_data<float>(platform::CPUPlace());
6962
auto* scale_data = scale_tensor->mutable_data<float>(platform::CPUPlace());
7063

71-
plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin(
72-
bias_data, bias_tensor->numel(), scale_data, scale_tensor->numel(),
73-
begin_norm_axis, eps, mean_shape, variance_shape);
74-
nvinfer1::IPluginLayer* layernorm_layer = engine_->AddPlugin(&X, 1, plugin);
64+
nvinfer1::ILayer* layernorm_layer = nullptr;
65+
if (engine_->with_dynamic_shape()) {
66+
int input_num = 1;
67+
for (int i = begin_norm_axis; i < X->getDimensions().nbDims; i++) {
68+
input_num *= X->getDimensions().d[i];
69+
}
70+
std::vector<int64_t> mean_shape{input_num};
71+
std::vector<int64_t> variance_shape{input_num};
72+
plugin::LayerNormPluginDynamic* plugin =
73+
new plugin::LayerNormPluginDynamic(bias_data, bias_tensor->numel(),
74+
scale_data, scale_tensor->numel(),
75+
begin_norm_axis, eps, mean_shape,
76+
variance_shape);
77+
layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
78+
} else {
79+
int input_num = 1;
80+
for (int i = begin_norm_axis - 1; i < X->getDimensions().nbDims; i++) {
81+
input_num *= X->getDimensions().d[i];
82+
}
83+
std::vector<int64_t> mean_shape{input_num};
84+
std::vector<int64_t> variance_shape{input_num};
85+
plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin(
86+
bias_data, bias_tensor->numel(), scale_data, scale_tensor->numel(),
87+
begin_norm_axis, eps, mean_shape, variance_shape);
88+
layernorm_layer = engine_->AddPlugin(
89+
&X, 1, reinterpret_cast<plugin::PluginTensorRT*>(plugin));
90+
}
7591

7692
auto output_name = op_desc.Output("Y").front();
7793
engine_->SetWeights(op_desc.Input("Bias").front(), std::move(bias_tensor));

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
700700
}
701701

702702
if (op_type == "reshape" || op_type == "reshape2") {
703-
if (!desc.HasAttr("shape") || with_dynamic_shape) {
703+
if (!desc.HasAttr("shape")) {
704704
return false;
705705
// Paddle-TRT does not support the input tensors: Shape and ShapeTensor
706706
} else if (desc.Input("Shape").size() >= 1 ||

paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,18 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
5757
input_shape.push_back(input_dims.d[i]);
5858
}
5959
const auto input_ddim = framework::make_ddim(input_shape);
60-
auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis - 1);
60+
auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis);
6161
int feature_size = static_cast<int>(matrix_dim[1]);
62+
PADDLE_ENFORCE_EQ(feature_size, scale_.size(),
63+
platform::errors::InvalidArgument(
64+
"scale's size should be equal to the feature_size,"
65+
"but got feature_size:%d, scale's size:%d.",
66+
feature_size, scale_.size()));
67+
PADDLE_ENFORCE_EQ(feature_size, bias_.size(),
68+
platform::errors::InvalidArgument(
69+
"bias's size should be equal to the feature_size,"
70+
"but got feature_size:%d, bias's size:%d.",
71+
feature_size, bias_.size()));
6272

6373
scale_t.Resize(framework::make_ddim({feature_size}));
6474
bias_t.Resize(framework::make_ddim({feature_size}));
@@ -82,6 +92,103 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
8292
return cudaGetLastError() != cudaSuccess;
8393
}
8494

95+
// Reports the symbolic shape of the plugin output: the dimension expressions
// of the first input are returned unchanged. `output_index`, `nb_inputs` and
// `expr_builder` are not consulted.
nvinfer1::DimsExprs LayerNormPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputDims, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) {
  const nvinfer1::DimsExprs &first_input_dims = *inputDims;
  return first_input_dims;
}
100+
101+
bool LayerNormPluginDynamic::supportsFormatCombination(
102+
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
103+
int nb_outputs) {
104+
PADDLE_ENFORCE_NOT_NULL(
105+
in_out, platform::errors::InvalidArgument(
106+
"The input of layernorm plugin shoule not be nullptr."));
107+
PADDLE_ENFORCE_LT(
108+
pos, nb_inputs + nb_outputs,
109+
platform::errors::InvalidArgument("The pos(%d) should be less than the "
110+
"num(%d) of the input and the output.",
111+
pos, nb_inputs + nb_outputs));
112+
const nvinfer1::PluginTensorDesc &in = in_out[pos];
113+
if (pos == 0) {
114+
// TODO(Shangzhizhou) FP16 support
115+
return (in.type == nvinfer1::DataType::kFLOAT) &&
116+
(in.format == nvinfer1::TensorFormat::kLINEAR);
117+
}
118+
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
119+
// output
120+
return in.type == prev.type && in.format == prev.format;
121+
}
122+
123+
// Returns the data type of output `index`, which is the type of the first
// input. Enforces index == 0; `nb_inputs` is not consulted.
nvinfer1::DataType LayerNormPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
  // `index` selects an output, so the diagnostic talks about outputs.
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The LayerNormPlugin only has one output, so the "
                        "index value should be 0, but get %d.",
                        index));
  return input_types[0];
}
132+
133+
int LayerNormPluginDynamic::enqueue(
134+
const nvinfer1::PluginTensorDesc *input_desc,
135+
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
136+
void *const *outputs, void *workspace, cudaStream_t stream) {
137+
const auto &input_dims = input_desc[0].dims;
138+
int begin_norm_axis = begin_norm_axis_;
139+
float eps = eps_;
140+
141+
std::vector<int> input_shape;
142+
for (int i = 0; i < input_dims.nbDims; i++) {
143+
input_shape.push_back(input_dims.d[i]);
144+
}
145+
const auto input_ddim = framework::make_ddim(input_shape);
146+
auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis);
147+
int feature_size = static_cast<int>(matrix_dim[1]);
148+
PADDLE_ENFORCE_EQ(feature_size, scale_.size(),
149+
platform::errors::InvalidArgument(
150+
"scale's size should be equal to the feature_size,"
151+
"but got feature_size:%d, scale's size:%d.",
152+
feature_size, scale_.size()));
153+
PADDLE_ENFORCE_EQ(feature_size, bias_.size(),
154+
platform::errors::InvalidArgument(
155+
"bias's size should be equal to the feature_size,"
156+
"but got feature_size:%d, bias's size:%d.",
157+
feature_size, bias_.size()));
158+
int device_id;
159+
cudaGetDevice(&device_id);
160+
auto input_type = input_desc[0].type;
161+
if (input_type == nvinfer1::DataType::kFLOAT) {
162+
VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp32";
163+
const float *input = reinterpret_cast<const float *>(inputs[0]);
164+
float *output = static_cast<float *>(outputs[0]);
165+
scale_t.Resize(framework::make_ddim({feature_size}));
166+
bias_t.Resize(framework::make_ddim({feature_size}));
167+
mean_t.Resize(framework::make_ddim(mean_shape_));
168+
variance_t.Resize(framework::make_ddim(variance_shape_));
169+
170+
float *scale_d =
171+
scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
172+
float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
173+
float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
174+
float *variance_d =
175+
variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
176+
177+
cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * feature_size,
178+
cudaMemcpyHostToDevice, stream);
179+
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
180+
cudaMemcpyHostToDevice, stream);
181+
182+
paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
183+
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
184+
variance_d, begin_norm_axis, eps);
185+
} else {
186+
PADDLE_THROW(platform::errors::Fatal(
187+
"The LayerNorm TRT Plugin's input type should be float."));
188+
}
189+
return cudaGetLastError() != cudaSuccess;
190+
}
191+
85192
} // namespace plugin
86193
} // namespace tensorrt
87194
} // namespace inference

paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h

Lines changed: 141 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class LayerNormPlugin : public PluginTensorRT {
5050
// TRT will call this func when we need to serialize the configuration of
5151
// tensorrt.
5252
// It should not be called by users.
53-
void serialize(void *buffer) override {
53+
void serialize(void* buffer) override {
5454
SerializeValue(&buffer, getPluginType());
5555
serializeBase(buffer);
5656
SerializeValue(&buffer, bias_);
@@ -62,7 +62,7 @@ class LayerNormPlugin : public PluginTensorRT {
6262
}
6363

6464
public:
65-
LayerNormPlugin(const float *bias, const int bias_num, const float *scale,
65+
LayerNormPlugin(const float* bias, const int bias_num, const float* scale,
6666
const int scale_num, int begin_norm_axis, float eps,
6767
std::vector<int64_t> mean_shape,
6868
std::vector<int64_t> variance_shape)
@@ -78,7 +78,7 @@ class LayerNormPlugin : public PluginTensorRT {
7878

7979
// It was used for tensorrt deserialization.
8080
// It should not be called by users.
81-
LayerNormPlugin(void const *serialData, size_t serialLength) {
81+
LayerNormPlugin(void const* serialData, size_t serialLength) {
8282
deserializeBase(serialData, serialLength);
8383
DeserializeValue(&serialData, &serialLength, &bias_);
8484
DeserializeValue(&serialData, &serialLength, &scale_);
@@ -90,20 +90,153 @@ class LayerNormPlugin : public PluginTensorRT {
9090
~LayerNormPlugin() {}
9191
int initialize() override;
9292

93-
LayerNormPlugin *clone() const override {
93+
LayerNormPlugin* clone() const override {
9494
return new LayerNormPlugin(bias_.data(), bias_.size(), scale_.data(),
9595
scale_.size(), begin_norm_axis_, eps_,
9696
mean_shape_, variance_shape_);
9797
}
9898

99-
const char *getPluginType() const override { return "layer_norm_plugin"; }
99+
const char* getPluginType() const override { return "layer_norm_plugin"; }
100100
int getNbOutputs() const override { return 1; }
101-
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
101+
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
102102
int nbInputDims) override;
103-
int enqueue(int batchSize, const void *const *inputs, void **outputs,
104-
void *workspace, cudaStream_t stream) override;
103+
int enqueue(int batchSize, const void* const* inputs, void** outputs,
104+
void* workspace, cudaStream_t stream) override;
105105
};
106106

107+
class LayerNormPluginDynamic : public DynamicPluginTensorRT {
108+
public:
109+
LayerNormPluginDynamic(const float* bias, const int bias_num,
110+
const float* scale, const int scale_num,
111+
int begin_norm_axis, float eps,
112+
std::vector<int64_t> mean_shape,
113+
std::vector<int64_t> variance_shape)
114+
: begin_norm_axis_(begin_norm_axis),
115+
eps_(eps),
116+
mean_shape_(mean_shape),
117+
variance_shape_(variance_shape) {
118+
bias_.resize(bias_num);
119+
scale_.resize(scale_num);
120+
std::copy(bias, bias + bias_num, bias_.data());
121+
std::copy(scale, scale + scale_num, scale_.data());
122+
}
123+
124+
LayerNormPluginDynamic(void const* serialData, size_t serialLength) {
125+
DeserializeValue(&serialData, &serialLength, &bias_);
126+
DeserializeValue(&serialData, &serialLength, &scale_);
127+
DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
128+
DeserializeValue(&serialData, &serialLength, &eps_);
129+
DeserializeValue(&serialData, &serialLength, &mean_shape_);
130+
DeserializeValue(&serialData, &serialLength, &variance_shape_);
131+
}
132+
nvinfer1::IPluginV2DynamicExt* clone() const override {
133+
return new LayerNormPluginDynamic(bias_.data(), bias_.size(), scale_.data(),
134+
scale_.size(), begin_norm_axis_, eps_,
135+
mean_shape_, variance_shape_);
136+
}
137+
138+
const char* getPluginType() const override { return "layernorm_plugin"; }
139+
int getNbOutputs() const override { return 1; }
140+
int initialize() override { return 0; }
141+
142+
size_t getSerializationSize() const override {
143+
return SerializedSize(bias_) + SerializedSize(scale_) +
144+
SerializedSize(begin_norm_axis_) + SerializedSize(eps_) +
145+
SerializedSize(mean_shape_) + SerializedSize(variance_shape_);
146+
}
147+
148+
void serialize(void* buffer) const override {
149+
SerializeValue(&buffer, bias_);
150+
SerializeValue(&buffer, scale_);
151+
SerializeValue(&buffer, begin_norm_axis_);
152+
SerializeValue(&buffer, eps_);
153+
SerializeValue(&buffer, mean_shape_);
154+
SerializeValue(&buffer, variance_shape_);
155+
}
156+
157+
nvinfer1::DimsExprs getOutputDimensions(
158+
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
159+
nvinfer1::IExprBuilder& expr_builder) override;
160+
161+
bool supportsFormatCombination(int pos,
162+
const nvinfer1::PluginTensorDesc* inOut,
163+
int nbInputs, int nbOutputs) override;
164+
165+
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
166+
int nbInputs,
167+
const nvinfer1::DynamicPluginTensorDesc* out,
168+
int nbOutputs) override {}
169+
170+
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
171+
int nbInputs,
172+
const nvinfer1::PluginTensorDesc* outputs,
173+
int nbOutputs) const override {
174+
return 0;
175+
}
176+
177+
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
178+
const nvinfer1::PluginTensorDesc* outputDesc,
179+
const void* const* inputs, void* const* outputs, void* workspace,
180+
cudaStream_t stream) override;
181+
nvinfer1::DataType getOutputDataType(int index,
182+
const nvinfer1::DataType* inputTypes,
183+
int nbInputs) const override;
184+
185+
void destroy() override { delete this; }
186+
187+
private:
188+
std::vector<float> bias_;
189+
std::vector<float> scale_;
190+
framework::Tensor scale_t;
191+
framework::Tensor bias_t;
192+
framework::Tensor mean_t;
193+
framework::Tensor variance_t;
194+
int begin_norm_axis_;
195+
float eps_;
196+
std::vector<int64_t> mean_shape_;
197+
std::vector<int64_t> variance_shape_;
198+
};
199+
200+
class LayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
201+
public:
202+
LayerNormPluginDynamicCreator() {}
203+
const char* getPluginName() const override { return "layernorm_plugin"; }
204+
205+
const char* getPluginVersion() const override { return "1"; }
206+
207+
const nvinfer1::PluginFieldCollection* getFieldNames() override {
208+
return &field_collection_;
209+
}
210+
211+
nvinfer1::IPluginV2* createPlugin(
212+
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
213+
return nullptr;
214+
}
215+
216+
nvinfer1::IPluginV2* deserializePlugin(const char* name,
217+
const void* serial_data,
218+
size_t serial_length) override {
219+
auto plugin = new LayerNormPluginDynamic(serial_data, serial_length);
220+
return plugin;
221+
}
222+
223+
void setPluginNamespace(const char* lib_namespace) override {
224+
plugin_namespace_ = lib_namespace;
225+
}
226+
227+
const char* getPluginNamespace() const override {
228+
return plugin_namespace_.c_str();
229+
}
230+
231+
private:
232+
std::string plugin_namespace_;
233+
std::string plugin_name_;
234+
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
235+
std::vector<nvinfer1::PluginField> plugin_attributes_;
236+
};
237+
238+
// Hands LayerNormPluginDynamicCreator to the project's REGISTER_TRT_PLUGIN_V2
// macro (defined outside this file). NOTE(review): presumably this makes the
// creator discoverable by name/version when an engine is deserialized —
// confirm against the macro definition.
REGISTER_TRT_PLUGIN_V2(LayerNormPluginDynamicCreator);
239+
107240
} // namespace plugin
108241
} // namespace tensorrt
109242
} // namespace inference

paddle/fluid/pybind/inference_api.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ void BindAnalysisConfig(py::module *m) {
511511
py::arg("disable_trt_plugin_fp16") = false)
512512
.def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS)
513513
.def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled)
514+
.def("exp_disable_tensorrt_ops", &AnalysisConfig::Exp_DisableTensorRtOPs)
514515
.def("enable_tensorrt_dla", &AnalysisConfig::EnableTensorRtDLA,
515516
py::arg("dla_core") = 0)
516517
.def("tensorrt_dla_enabled", &AnalysisConfig::tensorrt_dla_enabled)

python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ def check_output_with_option(self,
160160
use_gpu,
161161
atol=1e-5,
162162
flatten=False,
163-
quant=False):
163+
quant=False,
164+
rtol=1e-5):
164165
'''
165166
Check whether calculating on CPU and GPU, enable TensorRT
166167
or disable TensorRT, enable MKLDNN or disable MKLDNN
@@ -260,7 +261,7 @@ def check_output_with_option(self,
260261

261262
self.assertTrue(
262263
np.allclose(
263-
out, tensorrt_output, atol=atol),
264+
out, tensorrt_output, rtol=rtol, atol=atol),
264265
"Output has diff between GPU and TensorRT. ")
265266

266267
# Check whether the mkldnn results and the CPU results are the same.

0 commit comments

Comments
 (0)