@@ -44,8 +44,42 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
44
44
hidden_size_(hidden_size),
45
45
eps_(eps) {}
46
46
47
  // Deserialization constructor: rebuilds the plugin from the byte stream
  // produced by serialize(). Fields are consumed in the exact order
  // serialize() writes them; each float table is uploaded straight into a
  // freshly cudaMalloc'ed device buffer and the host-side pointers are left
  // null, since only the GPU copies are used after deserialization.
  EmbEltwiseLayernormPluginDynamic(void const* serial_data,
                                   size_t serial_length) {
    DeserializeValue(&serial_data, &serial_length, &emb_sizes_);

    embs_gpu_.resize(emb_sizes_.size());
    embs_.resize(emb_sizes_.size());
    for (size_t i = 0; i < emb_sizes_.size(); i++) {
      // NOTE(review): cudaMalloc/cudaMemcpy return codes are ignored here,
      // so an allocation/copy failure would go unnoticed — consider checking.
      cudaMalloc(&embs_gpu_[i], sizeof(float) * emb_sizes_[i]);
      cudaMemcpy(embs_gpu_[i], serial_data, emb_sizes_[i] * sizeof(float),
                 cudaMemcpyHostToDevice);
      // Advance the read cursor past the raw embedding table just consumed.
      reinterpret_cast<char const*&>(serial_data) +=
          emb_sizes_[i] * sizeof(float);
      serial_length -= emb_sizes_[i] * sizeof(float);
      embs_[i] = nullptr;  // host copy intentionally absent after deserialize
    }
    DeserializeValue(&serial_data, &serial_length, &bias_size_);
    DeserializeValue(&serial_data, &serial_length, &scale_size_);

    // Layernorm bias table: upload to device, then advance the cursor.
    cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
    cudaMemcpy(bias_gpu_, serial_data, bias_size_ * sizeof(float),
               cudaMemcpyHostToDevice);
    bias_ = nullptr;
    reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(float);
    serial_length -= bias_size_ * sizeof(float);

    // Layernorm scale table: same treatment as the bias table above.
    cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
    cudaMemcpy(scale_gpu_, serial_data, scale_size_ * sizeof(float),
               cudaMemcpyHostToDevice);
    scale_ = nullptr;
    reinterpret_cast<char const*&>(serial_data) += scale_size_ * sizeof(float);
    serial_length -= scale_size_ * sizeof(float);

    DeserializeValue(&serial_data, &serial_length, &hidden_size_);
    DeserializeValue(&serial_data, &serial_length, &eps_);
  }
49
83
nvinfer1::IPluginV2DynamicExt* clone () const override {
50
84
return new EmbEltwiseLayernormPluginDynamic (
51
85
embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_, hidden_size_,
@@ -58,36 +92,66 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
  // The fused embedding + eltwise + layernorm op produces a single output
  // tensor.
  int getNbOutputs() const override { return 1; }
  int initialize() override;
61
- size_t getSerializationSize () const override ;
62
- void serialize (void * buffer) const override ;
95
+ size_t getSerializationSize () const override {
96
+ int sum_num = 0 ;
97
+ sum_num += SerializedSize (emb_sizes_);
98
+
99
+ for (size_t i = 0 ; i < emb_sizes_.size (); i++) {
100
+ sum_num += emb_sizes_[i] * sizeof (float );
101
+ }
102
+
103
+ sum_num += SerializedSize (bias_size_);
104
+ sum_num += SerializedSize (scale_size_);
105
+
106
+ sum_num += (bias_size_ + scale_size_) * sizeof (float );
107
+ sum_num += SerializedSize (hidden_size_);
108
+ sum_num += SerializedSize (eps_);
109
+ // sum_num += SerializedSize(with_fp16_);
110
+
111
+ return sum_num;
112
+ }
113
  // Writes the plugin state into `buffer` in the exact order the
  // deserializing constructor reads it back: emb sizes, each embedding
  // table (copied down from the GPU), bias/scale sizes, bias and scale
  // tables, hidden size, epsilon. Any reordering here must be mirrored in
  // the constructor and in getSerializationSize().
  void serialize(void* buffer) const override {
    // SerializeValue(&buffer, with_fp16_);
    SerializeValue(&buffer, emb_sizes_);
    for (size_t i = 0; i < emb_sizes_.size(); i++) {
      // SerializeCudaPointer copies device memory into the host buffer —
      // presumably via cudaMemcpy; see its definition to confirm.
      SerializeCudaPointer(&buffer, embs_gpu_[i], emb_sizes_[i]);
    }
    SerializeValue(&buffer, bias_size_);
    SerializeValue(&buffer, scale_size_);
    SerializeCudaPointer(&buffer, bias_gpu_, bias_size_);
    SerializeCudaPointer(&buffer, scale_gpu_, scale_size_);
    SerializeValue(&buffer, hidden_size_);
    SerializeValue(&buffer, eps_);
  }
63
127
64
128
  // Output shape computation lives in the implementation file.
  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* in_out,
                                 int nb_inputs, int nb_outputs) override;

  // Nothing to configure: all plugin state is fixed at construction time.
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nb_inputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nb_outputs) override {}

  // No scratch workspace is requested from TensorRT.
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nb_inputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nb_outputs) const override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
              const nvinfer1::PluginTensorDesc* output_desc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* input_types,
                                       int nb_inputs) const override;

  // TensorRT owns plugin lifetime through this call.
  void destroy() override { delete this; }
93
157
@@ -99,14 +163,57 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
99
163
  // data on devices
  float* bias_gpu_;               // layernorm bias, device copy
  float* scale_gpu_;              // layernorm scale, device copy
  std::vector<float*> embs_gpu_;  // one device table per embedding input

  std::vector<int> emb_sizes_;  // element count of each embedding table
  int bias_size_;               // element count of bias_gpu_
  int scale_size_;              // element count of scale_gpu_
  int hidden_size_;
  float eps_;  // layernorm epsilon
};
174
+
175
+ class EmbEltwiseLayernormPluginV2Creator : public nvinfer1 ::IPluginCreator {
176
+ public:
177
+ EmbEltwiseLayernormPluginV2Creator () {}
178
+ const char * getPluginName () const override {
179
+ return " fused_embedding_eltwise_layernorm_plugin" ;
180
+ }
181
+
182
+ const char * getPluginVersion () const override { return " 1" ; }
183
+
184
+ const nvinfer1::PluginFieldCollection* getFieldNames () override {
185
+ return &field_collection_;
186
+ }
187
+
188
+ nvinfer1::IPluginV2* createPlugin (
189
+ const char * name, const nvinfer1::PluginFieldCollection* fc) override {
190
+ return nullptr ;
191
+ }
192
+
193
+ nvinfer1::IPluginV2* deserializePlugin (const char * name,
194
+ const void * serial_data,
195
+ size_t serial_length) override {
196
+ return new EmbEltwiseLayernormPluginDynamic<float >(serial_data,
197
+ serial_length);
198
+ }
199
+
200
+ void setPluginNamespace (const char * lib_namespace) override {
201
+ plugin_namespace_ = lib_namespace;
202
+ }
203
+
204
+ const char * getPluginNamespace () const override {
205
+ return plugin_namespace_.c_str ();
206
+ }
207
+
208
+ private:
209
+ std::string plugin_namespace_;
210
+ std::string plugin_name_;
211
+ nvinfer1::PluginFieldCollection field_collection_;
212
+ std::vector<nvinfer1::PluginField> plugin_attributes_;
213
+ };
214
+
215
+ REGISTER_TRT_PLUGIN_V2 (EmbEltwiseLayernormPluginV2Creator);
216
+
110
217
#endif
111
218
} // namespace plugin
112
219
} // namespace tensorrt
0 commit comments