Skip to content

Commit 9742811

Browse files
author
Pei Yang
authored
[Cherry pick] Trt ernie serialization (#25956)

* solve conflict
* fix crash when trt not found in python; update unittest model path
1 parent 897305d commit 9742811

19 files changed: +644 additions, -176 deletions

paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2-
32
Licensed under the Apache License, Version 2.0 (the "License");
43
you may not use this file except in compliance with the License.
54
You may obtain a copy of the License at
6-
75
http://www.apache.org/licenses/LICENSE-2.0
8-
96
Unless required by applicable law or agreed to in writing, software
107
distributed under the License is distributed on an "AS IS" BASIS,
118
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -83,23 +80,10 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
8380
nvinfer1::ILayer* layer = nullptr;
8481

8582
if (engine_->with_dynamic_shape()) {
86-
auto use_fp16 = engine_->WithFp16();
8783
plugin::DynamicPluginTensorRT* plugin = nullptr;
88-
if (use_fp16) {
89-
#ifdef SUPPORTS_CUDA_FP16
90-
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<half>(
91-
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
92-
eps);
93-
#else
94-
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
95-
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
96-
eps);
97-
#endif
98-
} else {
99-
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
100-
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
101-
eps);
102-
}
84+
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
85+
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
86+
eps);
10387
layer = engine_->AddPluginV2(input_ids.data(), input_num, plugin);
10488
} else {
10589
PADDLE_THROW(platform::errors::Fatal(

paddle/fluid/inference/tensorrt/engine.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,23 @@ class TensorRTEngine {
200200
void Deserialize(const std::string& engine_serialized_data) {
201201
freshDeviceId();
202202
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
203-
infer_engine_.reset(runtime->deserializeCudaEngine(
204-
engine_serialized_data.c_str(), engine_serialized_data.size(),
205-
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
203+
if (with_dynamic_shape_) {
204+
#if IS_TRT_VERSION_GE(6000)
205+
infer_engine_.reset(runtime->deserializeCudaEngine(
206+
engine_serialized_data.c_str(), engine_serialized_data.size(),
207+
nullptr));
208+
#else
209+
210+
PADDLE_THROW(platform::errors::PreconditionNotMet(
211+
"To enable dynamic shape support, the TensorRT version should be "
212+
"greater than 6.0.0"));
213+
214+
#endif
215+
} else {
216+
infer_engine_.reset(runtime->deserializeCudaEngine(
217+
engine_serialized_data.c_str(), engine_serialized_data.size(),
218+
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
219+
}
206220
PADDLE_ENFORCE(infer_engine_ != nullptr,
207221
"build cuda engine failed when deserialize engine info.!");
208222
}

paddle/fluid/inference/tensorrt/helper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
5656
return static_cast<nvinfer1::IRuntime*>(
5757
dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
5858
}
59+
static nvinfer1::IPluginRegistry* getPluginRegistry() {
60+
return static_cast<nvinfer1::IPluginRegistry*>(dy::getPluginRegistry());
61+
}
5962

6063
// A logger for create TensorRT infer builder.
6164
class NaiveLogger : public nvinfer1::ILogger {

paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu

Lines changed: 15 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -33,53 +33,29 @@ namespace plugin {
3333

3434
template <typename T>
3535
int EmbEltwiseLayernormPluginDynamic<T>::initialize() {
36-
int nb_emb = embs_.size();
37-
std::vector<void *> ptr_vector(nb_emb);
38-
std::vector<std::vector<half>> emb_fp16(nb_emb);
39-
40-
if (sizeof(T) == sizeof(float)) {
41-
// FP32
42-
for (int i = 0; i < nb_emb; ++i) {
43-
ptr_vector[i] = embs_[i];
44-
}
45-
} else {
46-
// FP16
47-
for (int i = 0; i < nb_emb; ++i) {
48-
auto emb_size = emb_sizes_[i];
49-
auto &tmp = emb_fp16[i];
50-
tmp.resize(emb_size);
51-
52-
for (int j = 0; j < emb_size; ++j) {
53-
tmp[j] = static_cast<half>(embs_[i][j]);
54-
}
55-
ptr_vector[i] = tmp.data();
56-
}
57-
}
5836
embs_gpu_.resize(embs_.size());
5937
for (int i = 0; i < embs_.size(); i++) {
60-
cudaMalloc(&embs_gpu_[i], sizeof(T) * emb_sizes_[i]);
61-
cudaMemcpy(embs_gpu_[i], ptr_vector[i], emb_sizes_[i] * sizeof(T),
62-
cudaMemcpyHostToDevice);
38+
if (embs_[i]) {
39+
cudaMalloc(&embs_gpu_[i], sizeof(float) * emb_sizes_[i]);
40+
cudaMemcpy(embs_gpu_[i], embs_[i], emb_sizes_[i] * sizeof(float),
41+
cudaMemcpyHostToDevice);
42+
}
6343
}
6444

65-
cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
66-
cudaMemcpy(bias_gpu_, bias_, bias_size_ * sizeof(float),
67-
cudaMemcpyHostToDevice);
68-
cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
69-
cudaMemcpy(scale_gpu_, scale_, scale_size_ * sizeof(float),
70-
cudaMemcpyHostToDevice);
71-
72-
return 0;
73-
}
45+
if (bias_) {
46+
cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
47+
cudaMemcpy(bias_gpu_, bias_, bias_size_ * sizeof(float),
48+
cudaMemcpyHostToDevice);
49+
}
50+
if (scale_) {
51+
cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
52+
cudaMemcpy(scale_gpu_, scale_, scale_size_ * sizeof(float),
53+
cudaMemcpyHostToDevice);
54+
}
7455

75-
template <typename T>
76-
size_t EmbEltwiseLayernormPluginDynamic<T>::getSerializationSize() const {
7756
return 0;
7857
}
7958

80-
template <typename T>
81-
void EmbEltwiseLayernormPluginDynamic<T>::serialize(void *buffer) const {}
82-
8359
template <typename T>
8460
nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
8561
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,

paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h

Lines changed: 122 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,42 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
4444
hidden_size_(hidden_size),
4545
eps_(eps) {}
4646

47-
EmbEltwiseLayernormPluginDynamic(void const* serialData,
48-
size_t serialLength) {}
47+
EmbEltwiseLayernormPluginDynamic(void const* serial_data,
48+
size_t serial_length) {
49+
DeserializeValue(&serial_data, &serial_length, &emb_sizes_);
50+
51+
embs_gpu_.resize(emb_sizes_.size());
52+
embs_.resize(emb_sizes_.size());
53+
for (size_t i = 0; i < emb_sizes_.size(); i++) {
54+
cudaMalloc(&embs_gpu_[i], sizeof(float) * emb_sizes_[i]);
55+
cudaMemcpy(embs_gpu_[i], serial_data, emb_sizes_[i] * sizeof(float),
56+
cudaMemcpyHostToDevice);
57+
reinterpret_cast<char const*&>(serial_data) +=
58+
emb_sizes_[i] * sizeof(float);
59+
serial_length -= emb_sizes_[i] * sizeof(float);
60+
embs_[i] = nullptr;
61+
}
62+
DeserializeValue(&serial_data, &serial_length, &bias_size_);
63+
DeserializeValue(&serial_data, &serial_length, &scale_size_);
64+
65+
cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
66+
cudaMemcpy(bias_gpu_, serial_data, bias_size_ * sizeof(float),
67+
cudaMemcpyHostToDevice);
68+
bias_ = nullptr;
69+
reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(float);
70+
serial_length -= bias_size_ * sizeof(float);
71+
72+
cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
73+
cudaMemcpy(scale_gpu_, serial_data, scale_size_ * sizeof(float),
74+
cudaMemcpyHostToDevice);
75+
scale_ = nullptr;
76+
reinterpret_cast<char const*&>(serial_data) += scale_size_ * sizeof(float);
77+
serial_length -= scale_size_ * sizeof(float);
78+
79+
DeserializeValue(&serial_data, &serial_length, &hidden_size_);
80+
DeserializeValue(&serial_data, &serial_length, &eps_);
81+
}
82+
4983
nvinfer1::IPluginV2DynamicExt* clone() const override {
5084
return new EmbEltwiseLayernormPluginDynamic(
5185
embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_, hidden_size_,
@@ -58,36 +92,66 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
5892
int getNbOutputs() const override { return 1; }
5993
int initialize() override;
6094

61-
size_t getSerializationSize() const override;
62-
void serialize(void* buffer) const override;
95+
size_t getSerializationSize() const override {
96+
int sum_num = 0;
97+
sum_num += SerializedSize(emb_sizes_);
98+
99+
for (size_t i = 0; i < emb_sizes_.size(); i++) {
100+
sum_num += emb_sizes_[i] * sizeof(float);
101+
}
102+
103+
sum_num += SerializedSize(bias_size_);
104+
sum_num += SerializedSize(scale_size_);
105+
106+
sum_num += (bias_size_ + scale_size_) * sizeof(float);
107+
sum_num += SerializedSize(hidden_size_);
108+
sum_num += SerializedSize(eps_);
109+
// sum_num += SerializedSize(with_fp16_);
110+
111+
return sum_num;
112+
}
113+
114+
void serialize(void* buffer) const override {
115+
// SerializeValue(&buffer, with_fp16_);
116+
SerializeValue(&buffer, emb_sizes_);
117+
for (size_t i = 0; i < emb_sizes_.size(); i++) {
118+
SerializeCudaPointer(&buffer, embs_gpu_[i], emb_sizes_[i]);
119+
}
120+
SerializeValue(&buffer, bias_size_);
121+
SerializeValue(&buffer, scale_size_);
122+
SerializeCudaPointer(&buffer, bias_gpu_, bias_size_);
123+
SerializeCudaPointer(&buffer, scale_gpu_, scale_size_);
124+
SerializeValue(&buffer, hidden_size_);
125+
SerializeValue(&buffer, eps_);
126+
}
63127

64128
nvinfer1::DimsExprs getOutputDimensions(
65129
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
66130
nvinfer1::IExprBuilder& expr_builder) override;
67131

68132
bool supportsFormatCombination(int pos,
69-
const nvinfer1::PluginTensorDesc* inOut,
70-
int nbInputs, int nbOutputs) override;
133+
const nvinfer1::PluginTensorDesc* in_out,
134+
int nb_inputs, int nb_outputs) override;
71135

72136
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
73-
int nbInputs,
137+
int nb_inputs,
74138
const nvinfer1::DynamicPluginTensorDesc* out,
75-
int nbOutputs) override {}
139+
int nb_outputs) override {}
76140

77141
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
78-
int nbInputs,
142+
int nb_inputs,
79143
const nvinfer1::PluginTensorDesc* outputs,
80-
int nbOutputs) const override {
144+
int nb_outputs) const override {
81145
return 0;
82146
}
83147

84-
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
85-
const nvinfer1::PluginTensorDesc* outputDesc,
148+
int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
149+
const nvinfer1::PluginTensorDesc* output_desc,
86150
const void* const* inputs, void* const* outputs, void* workspace,
87151
cudaStream_t stream) override;
88152
nvinfer1::DataType getOutputDataType(int index,
89-
const nvinfer1::DataType* inputTypes,
90-
int nbInputs) const override;
153+
const nvinfer1::DataType* input_types,
154+
int nb_inputs) const override;
91155

92156
void destroy() override { delete this; }
93157

@@ -99,14 +163,57 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
99163
// data on devices
100164
float* bias_gpu_;
101165
float* scale_gpu_;
102-
std::vector<T*> embs_gpu_;
166+
std::vector<float*> embs_gpu_;
103167

104168
std::vector<int> emb_sizes_;
105169
int bias_size_;
106170
int scale_size_;
107171
int hidden_size_;
108172
float eps_;
109173
};
174+
175+
class EmbEltwiseLayernormPluginV2Creator : public nvinfer1::IPluginCreator {
176+
public:
177+
EmbEltwiseLayernormPluginV2Creator() {}
178+
const char* getPluginName() const override {
179+
return "fused_embedding_eltwise_layernorm_plugin";
180+
}
181+
182+
const char* getPluginVersion() const override { return "1"; }
183+
184+
const nvinfer1::PluginFieldCollection* getFieldNames() override {
185+
return &field_collection_;
186+
}
187+
188+
nvinfer1::IPluginV2* createPlugin(
189+
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
190+
return nullptr;
191+
}
192+
193+
nvinfer1::IPluginV2* deserializePlugin(const char* name,
194+
const void* serial_data,
195+
size_t serial_length) override {
196+
return new EmbEltwiseLayernormPluginDynamic<float>(serial_data,
197+
serial_length);
198+
}
199+
200+
void setPluginNamespace(const char* lib_namespace) override {
201+
plugin_namespace_ = lib_namespace;
202+
}
203+
204+
const char* getPluginNamespace() const override {
205+
return plugin_namespace_.c_str();
206+
}
207+
208+
private:
209+
std::string plugin_namespace_;
210+
std::string plugin_name_;
211+
nvinfer1::PluginFieldCollection field_collection_;
212+
std::vector<nvinfer1::PluginField> plugin_attributes_;
213+
};
214+
215+
REGISTER_TRT_PLUGIN_V2(EmbEltwiseLayernormPluginV2Creator);
216+
110217
#endif
111218
} // namespace plugin
112219
} // namespace tensorrt

paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,6 @@ int GeluPlugin::enqueue(int batch_size, const void* const* inputs,
132132

133133
// Dynamic Plugin below.
134134
#if IS_TRT_VERSION_GE(6000)
135-
size_t GeluPluginDynamic::getSerializationSize() const { return 0; }
136-
137-
void GeluPluginDynamic::serialize(void* buffer) const {}
138135

139136
nvinfer1::DimsExprs GeluPluginDynamic::getOutputDimensions(
140137
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,

0 commit comments

Comments (0)