Diff summary: 2 files changed (+83, −70 lines) under paddlenlp/ops/faster_transformer/src.

Hunk 1: @@ -340,9 +340,11 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
340
340
cublasCreate (&cublas_handle_);
341
341
cublasSetStream (cublas_handle_, stream);
342
342
343
+ std::vector<paddle::Tensor> ret;
344
+
343
345
switch (input.type ()) {
344
346
case paddle::DataType::FLOAT16: {
345
- return decoding_kernel<paddle::DataType::FLOAT16>(
347
+ ret = decoding_kernel<paddle::DataType::FLOAT16>(
346
348
input,
347
349
mem_seq_len,
348
350
word_embedding,
@@ -393,9 +395,10 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
393
395
beam_search_diversity_rate,
394
396
cublas_handle_,
395
397
stream);
398
+ break ;
396
399
}
397
400
case paddle::DataType::FLOAT32: {
398
- return decoding_kernel<paddle::DataType::FLOAT32>(
401
+ ret = decoding_kernel<paddle::DataType::FLOAT32>(
399
402
input,
400
403
mem_seq_len,
401
404
word_embedding,
@@ -446,11 +449,16 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
446
449
beam_search_diversity_rate,
447
450
cublas_handle_,
448
451
stream);
452
+ break ;
449
453
}
450
454
default : {
451
455
PD_THROW (
452
456
" NOT supported data type. "
453
457
" Only float16 and float32 are supported. " );
458
+ break ;
454
459
}
455
460
}
461
+
462
+ cublasDestroy (cublas_handle_);
463
+ return ret;
456
464
}
Hunk 2 (second file): @@ -190,75 +190,80 @@ std::vector<paddle::Tensor> GPT2CUDAForward(
190
190
cublasCreate (&cublas_handle_);
191
191
cublasSetStream (cublas_handle_, stream);
192
192
193
+ std::vector<paddle::Tensor> ret;
194
+
193
195
if (use_fp16) {
194
- return gpt2_kernel<paddle::DataType::FLOAT16>(input,
195
- word_embedding,
196
- self_ln_weight,
197
- self_ln_bias,
198
- self_q_weight,
199
- self_q_bias,
200
- self_k_weight,
201
- self_k_bias,
202
- self_v_weight,
203
- self_v_bias,
204
- self_out_weight,
205
- self_out_bias,
206
- ffn_ln_weight,
207
- ffn_ln_bias,
208
- ffn_inter_weight,
209
- ffn_inter_bias,
210
- ffn_out_weight,
211
- ffn_out_bias,
212
- decoder_ln_weight,
213
- decoder_ln_bias,
214
- positional_embedding_weight,
215
- emb_weight,
216
- output_ids,
217
- topk,
218
- topp,
219
- max_len,
220
- n_head,
221
- size_per_head,
222
- num_layer,
223
- bos_id,
224
- eos_id,
225
- temperature,
226
- cublas_handle_,
227
- stream);
196
+ ret = gpt2_kernel<paddle::DataType::FLOAT16>(input,
197
+ word_embedding,
198
+ self_ln_weight,
199
+ self_ln_bias,
200
+ self_q_weight,
201
+ self_q_bias,
202
+ self_k_weight,
203
+ self_k_bias,
204
+ self_v_weight,
205
+ self_v_bias,
206
+ self_out_weight,
207
+ self_out_bias,
208
+ ffn_ln_weight,
209
+ ffn_ln_bias,
210
+ ffn_inter_weight,
211
+ ffn_inter_bias,
212
+ ffn_out_weight,
213
+ ffn_out_bias,
214
+ decoder_ln_weight,
215
+ decoder_ln_bias,
216
+ positional_embedding_weight,
217
+ emb_weight,
218
+ output_ids,
219
+ topk,
220
+ topp,
221
+ max_len,
222
+ n_head,
223
+ size_per_head,
224
+ num_layer,
225
+ bos_id,
226
+ eos_id,
227
+ temperature,
228
+ cublas_handle_,
229
+ stream);
228
230
} else {
229
- return gpt2_kernel<paddle::DataType::FLOAT32>(input,
230
- word_embedding,
231
- self_ln_weight,
232
- self_ln_bias,
233
- self_q_weight,
234
- self_q_bias,
235
- self_k_weight,
236
- self_k_bias,
237
- self_v_weight,
238
- self_v_bias,
239
- self_out_weight,
240
- self_out_bias,
241
- ffn_ln_weight,
242
- ffn_ln_bias,
243
- ffn_inter_weight,
244
- ffn_inter_bias,
245
- ffn_out_weight,
246
- ffn_out_bias,
247
- decoder_ln_weight,
248
- decoder_ln_bias,
249
- positional_embedding_weight,
250
- emb_weight,
251
- output_ids,
252
- topk,
253
- topp,
254
- max_len,
255
- n_head,
256
- size_per_head,
257
- num_layer,
258
- bos_id,
259
- eos_id,
260
- temperature,
261
- cublas_handle_,
262
- stream);
231
+ ret = gpt2_kernel<paddle::DataType::FLOAT32>(input,
232
+ word_embedding,
233
+ self_ln_weight,
234
+ self_ln_bias,
235
+ self_q_weight,
236
+ self_q_bias,
237
+ self_k_weight,
238
+ self_k_bias,
239
+ self_v_weight,
240
+ self_v_bias,
241
+ self_out_weight,
242
+ self_out_bias,
243
+ ffn_ln_weight,
244
+ ffn_ln_bias,
245
+ ffn_inter_weight,
246
+ ffn_inter_bias,
247
+ ffn_out_weight,
248
+ ffn_out_bias,
249
+ decoder_ln_weight,
250
+ decoder_ln_bias,
251
+ positional_embedding_weight,
252
+ emb_weight,
253
+ output_ids,
254
+ topk,
255
+ topp,
256
+ max_len,
257
+ n_head,
258
+ size_per_head,
259
+ num_layer,
260
+ bos_id,
261
+ eos_id,
262
+ temperature,
263
+ cublas_handle_,
264
+ stream);
263
265
}
266
+
267
+ cublasDestroy (cublas_handle_);
268
+ return ret;
264
269
}
0 commit comments