@@ -50,7 +50,7 @@ class LayerNormPlugin : public PluginTensorRT {
  // TRT will call this func when we need to serialize the configuration of
  // tensorrt.
  // It should not be called by users.
-  void serialize(void *buffer) override {
+  void serialize(void* buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, bias_);
@@ -62,7 +62,7 @@ class LayerNormPlugin : public PluginTensorRT {
  }

 public:
-  LayerNormPlugin(const float *bias, const int bias_num, const float *scale,
+  LayerNormPlugin(const float* bias, const int bias_num, const float* scale,
                  const int scale_num, int begin_norm_axis, float eps,
                  std::vector<int64_t> mean_shape,
                  std::vector<int64_t> variance_shape)
@@ -78,7 +78,7 @@ class LayerNormPlugin : public PluginTensorRT {

  // It was used for tensorrt deserialization.
  // It should not be called by users.
-  LayerNormPlugin(void const *serialData, size_t serialLength) {
+  LayerNormPlugin(void const* serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &bias_);
    DeserializeValue(&serialData, &serialLength, &scale_);
@@ -90,20 +90,153 @@ class LayerNormPlugin : public PluginTensorRT {
  ~LayerNormPlugin() {}
  int initialize() override;

-  LayerNormPlugin *clone() const override {
+  LayerNormPlugin* clone() const override {
    return new LayerNormPlugin(bias_.data(), bias_.size(), scale_.data(),
                               scale_.size(), begin_norm_axis_, eps_,
                               mean_shape_, variance_shape_);
  }

-  const char *getPluginType() const override { return "layer_norm_plugin"; }
+  const char* getPluginType() const override { return "layer_norm_plugin"; }
  int getNbOutputs() const override { return 1; }
-  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
+  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                     int nbInputDims) override;
-  int enqueue(int batchSize, const void *const *inputs, void **outputs,
-              void *workspace, cudaStream_t stream) override;
+  int enqueue(int batchSize, const void* const* inputs, void** outputs,
+              void* workspace, cudaStream_t stream) override;
};

+class LayerNormPluginDynamic : public DynamicPluginTensorRT {
+ public:
+  LayerNormPluginDynamic(const float* bias, const int bias_num,
+                         const float* scale, const int scale_num,
+                         int begin_norm_axis, float eps,
+                         std::vector<int64_t> mean_shape,
+                         std::vector<int64_t> variance_shape)
+      : begin_norm_axis_(begin_norm_axis),
+        eps_(eps),
+        mean_shape_(mean_shape),
+        variance_shape_(variance_shape) {
+    bias_.resize(bias_num);
+    scale_.resize(scale_num);
+    std::copy(bias, bias + bias_num, bias_.data());
+    std::copy(scale, scale + scale_num, scale_.data());
+  }
+
+  LayerNormPluginDynamic(void const* serialData, size_t serialLength) {
+    DeserializeValue(&serialData, &serialLength, &bias_);
+    DeserializeValue(&serialData, &serialLength, &scale_);
+    DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
+    DeserializeValue(&serialData, &serialLength, &eps_);
+    DeserializeValue(&serialData, &serialLength, &mean_shape_);
+    DeserializeValue(&serialData, &serialLength, &variance_shape_);
+  }
+  nvinfer1::IPluginV2DynamicExt* clone() const override {
+    return new LayerNormPluginDynamic(bias_.data(), bias_.size(), scale_.data(),
+                                      scale_.size(), begin_norm_axis_, eps_,
+                                      mean_shape_, variance_shape_);
+  }
+
+  const char* getPluginType() const override { return "layernorm_plugin"; }
+  int getNbOutputs() const override { return 1; }
+  int initialize() override { return 0; }
+
+  size_t getSerializationSize() const override {
+    return SerializedSize(bias_) + SerializedSize(scale_) +
+           SerializedSize(begin_norm_axis_) + SerializedSize(eps_) +
+           SerializedSize(mean_shape_) + SerializedSize(variance_shape_);
+  }
+
+  void serialize(void* buffer) const override {
+    SerializeValue(&buffer, bias_);
+    SerializeValue(&buffer, scale_);
+    SerializeValue(&buffer, begin_norm_axis_);
+    SerializeValue(&buffer, eps_);
+    SerializeValue(&buffer, mean_shape_);
+    SerializeValue(&buffer, variance_shape_);
+  }
+
+  nvinfer1::DimsExprs getOutputDimensions(
+      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
+      nvinfer1::IExprBuilder& expr_builder) override;
+
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc* inOut,
+                                 int nbInputs, int nbOutputs) override;
+
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
+                       int nbInputs,
+                       const nvinfer1::DynamicPluginTensorDesc* out,
+                       int nbOutputs) override {}
+
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                          int nbInputs,
+                          const nvinfer1::PluginTensorDesc* outputs,
+                          int nbOutputs) const override {
+    return 0;
+  }
+
+  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+              const nvinfer1::PluginTensorDesc* outputDesc,
+              const void* const* inputs, void* const* outputs, void* workspace,
+              cudaStream_t stream) override;
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType* inputTypes,
+                                       int nbInputs) const override;
+
+  void destroy() override { delete this; }
+
+ private:
+  std::vector<float> bias_;
+  std::vector<float> scale_;
+  framework::Tensor scale_t;
+  framework::Tensor bias_t;
+  framework::Tensor mean_t;
+  framework::Tensor variance_t;
+  int begin_norm_axis_;
+  float eps_;
+  std::vector<int64_t> mean_shape_;
+  std::vector<int64_t> variance_shape_;
+};
+
+class LayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
+ public:
+  LayerNormPluginDynamicCreator() {}
+  const char* getPluginName() const override { return "layernorm_plugin"; }
+
+  const char* getPluginVersion() const override { return "1"; }
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() override {
+    return &field_collection_;
+  }
+
+  nvinfer1::IPluginV2* createPlugin(
+      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
+    return nullptr;
+  }
+
+  nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                         const void* serial_data,
+                                         size_t serial_length) override {
+    auto plugin = new LayerNormPluginDynamic(serial_data, serial_length);
+    return plugin;
+  }
+
+  void setPluginNamespace(const char* lib_namespace) override {
+    plugin_namespace_ = lib_namespace;
+  }
+
+  const char* getPluginNamespace() const override {
+    return plugin_namespace_.c_str();
+  }
+
+ private:
+  std::string plugin_namespace_;
+  std::string plugin_name_;
+  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
+  std::vector<nvinfer1::PluginField> plugin_attributes_;
+};
+
+REGISTER_TRT_PLUGIN_V2(LayerNormPluginDynamicCreator);
+
} // namespace plugin
} // namespace tensorrt
} // namespace inference
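
Note: the hunk above only declares the dynamic-shape entry points (getOutputDimensions, supportsFormatCombination, enqueue, getOutputDataType); their definitions live in the plugin's .cu file, which is not shown here. As a rough sketch of what the shape/format handling of a shape-preserving op like layer norm typically looks like (an illustration under those assumptions, not this PR's actual implementation):

// Illustrative sketch only -- not the definitions shipped in this change.
// Layer norm preserves the input shape, so output dims mirror input 0.
nvinfer1::DimsExprs LayerNormPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) {
  return inputs[0];
}

// Minimal format filter: linear FP32 for the first input, and every other
// tensor must match the type/format of input 0.
bool LayerNormPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) {
  const nvinfer1::PluginTensorDesc& desc = inOut[pos];
  if (pos == 0) {
    return desc.type == nvinfer1::DataType::kFLOAT &&
           desc.format == nvinfer1::TensorFormat::kLINEAR;
  }
  return desc.type == inOut[0].type && desc.format == inOut[0].format;
}

// The single output keeps the input's data type.
nvinfer1::DataType LayerNormPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
  return inputTypes[0];
}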