@@ -83,20 +83,18 @@ std::vector<paddle::Tensor> decoding_kernel(
83
83
paddle::Tensor& output_ids,
84
84
paddle::Tensor& parent_ids,
85
85
paddle::Tensor& sequence_length,
86
- std::string decoding_strategy,
87
- int beam_size,
88
- int topk,
89
- float topp,
90
- int head_num_,
91
- int size_per_head_,
92
- int num_layer_,
93
- int start_id_,
94
- int end_id_,
95
- int64_t max_seq_len_,
96
- float beam_search_diversity_rate_,
97
- float alpha,
98
- cublasHandle_t cublas_handle_,
99
- cublasLtHandle_t cublaslt_handle_,
86
+ const std::string& decoding_strategy,
87
+ const int beam_size,
88
+ const int topk,
89
+ const float topp,
90
+ const int head_num_,
91
+ const int size_per_head_,
92
+ const int num_layer_,
93
+ const int start_id_,
94
+ const int end_id_,
95
+ const int64_t max_seq_len_,
96
+ const float beam_search_diversity_rate_,
97
+ const float alpha,
100
98
cudaStream_t stream) {
101
99
int beam_width_ = (decoding_strategy == " beam_search" ||
102
100
decoding_strategy == " beam_search_v2" )
@@ -119,8 +117,9 @@ std::vector<paddle::Tensor> decoding_kernel(
119
117
typedef typename traits_::data_t data_t_;
120
118
121
119
DecodingInitParam<DataType_> decoding_params;
122
- decoding_params.cublas_handle = cublas_handle_;
123
- decoding_params.cublaslt_handle = cublaslt_handle_;
120
+ decoding_params.cublas_handle = CublasHandle::GetInstance ()->cublas_handle_ ;
121
+ decoding_params.cublaslt_handle =
122
+ CublasHandle::GetInstance ()->cublaslt_handle_ ;
124
123
125
124
decoding_params.output_ids = output_ids.mutable_data <int >(input.place ());
126
125
decoding_params.parent_ids = parent_ids.mutable_data <int >(input.place ());
@@ -156,10 +155,14 @@ std::vector<paddle::Tensor> decoding_kernel(
156
155
DecoderInitParam<DataType_>* params =
157
156
new DecoderInitParam<DataType_>[num_layer_];
158
157
158
+ auto q_weight_shape = self_attn_query_weight[0 ].shape ();
159
+ auto k_weight_shape = self_attn_key_weight[0 ].shape ();
160
+ bool fuse_qkv = (q_weight_shape[1 ] == k_weight_shape[1 ]) ? false : true ;
161
+
159
162
for (int i = 0 ; i < num_layer_; i++) {
160
163
params[i].stream = stream;
161
- params[i].cublas_handle = cublas_handle_;
162
- params[i].cublaslt_handle = cublaslt_handle_;
164
+ params[i].cublas_handle = CublasHandle::GetInstance ()-> cublas_handle_ ;
165
+ params[i].cublaslt_handle = CublasHandle::GetInstance ()-> cublaslt_handle_ ;
163
166
164
167
if (decoding_strategy == " beam_search" ||
165
168
decoding_strategy == " beam_search_v2" ) {
@@ -292,7 +295,8 @@ std::vector<paddle::Tensor> decoding_kernel(
292
295
start_id_,
293
296
end_id_,
294
297
beam_search_diversity_rate_,
295
- true ); // is_fuse_topk_softMax
298
+ true , // is_fuse_topk_softMax
299
+ fuse_qkv); // is_fuse_qkv
296
300
297
301
decoding_beam_search_->forward (params, decoding_params);
298
302
@@ -314,7 +318,7 @@ std::vector<paddle::Tensor> decoding_kernel(
314
318
end_id_,
315
319
beam_search_diversity_rate_,
316
320
true , // is_fuse_topk_softMax
317
- false , // is_fuse_qkv
321
+ fuse_qkv , // is_fuse_qkv
318
322
true , // keep_alive_beam
319
323
alpha);
320
324
@@ -338,7 +342,8 @@ std::vector<paddle::Tensor> decoding_kernel(
338
342
start_id_,
339
343
end_id_,
340
344
candidate_num_,
341
- probability_threshold_);
345
+ probability_threshold_,
346
+ fuse_qkv);
342
347
343
348
decoding_sampling_->forward (params, decoding_params);
344
349
@@ -392,24 +397,20 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
392
397
paddle::Tensor& output_ids,
393
398
paddle::Tensor& parent_ids,
394
399
paddle::Tensor& sequence_length,
395
- std::string decoding_strategy,
396
- int beam_size,
397
- int topk,
398
- float topp,
399
- int n_head,
400
- int size_per_head,
401
- int num_layer,
402
- int bos_id,
403
- int eos_id,
404
- int64_t max_len,
405
- float beam_search_diversity_rate,
406
- float alpha) {
400
+ const std::string& decoding_strategy,
401
+ const int beam_size,
402
+ const int topk,
403
+ const float topp,
404
+ const int n_head,
405
+ const int size_per_head,
406
+ const int num_layer,
407
+ const int bos_id,
408
+ const int eos_id,
409
+ const int64_t max_len,
410
+ const float beam_search_diversity_rate,
411
+ const float alpha) {
407
412
auto stream = input.stream ();
408
- cublasHandle_t cublas_handle_;
409
- cublasCreate (&cublas_handle_);
410
- cublasLtHandle_t cublaslt_handle_;
411
- cublasLtCreate (&cublaslt_handle_);
412
- cublasSetStream (cublas_handle_, stream);
413
+ cublasSetStream (CublasHandle::GetInstance ()->cublas_handle_ , stream);
413
414
414
415
std::vector<paddle::Tensor> ret;
415
416
@@ -466,8 +467,6 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
466
467
max_len,
467
468
beam_search_diversity_rate,
468
469
alpha,
469
- cublas_handle_,
470
- cublaslt_handle_,
471
470
stream);
472
471
break ;
473
472
}
@@ -523,8 +522,6 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
523
522
max_len,
524
523
beam_search_diversity_rate,
525
524
alpha,
526
- cublas_handle_,
527
- cublaslt_handle_,
528
525
stream);
529
526
break ;
530
527
}
@@ -536,7 +533,5 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
536
533
}
537
534
}
538
535
539
- cublasDestroy (cublas_handle_);
540
- cublasLtDestroy (cublaslt_handle_);
541
536
return ret;
542
537
}
0 commit comments