@@ -65,6 +65,48 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
65
65
}
66
66
}
67
67
68
+ template <typename T, int THREADBLOCK_SIZE>
69
+ __global__ void beam_topK_kernel_general (const T* log_probs,
70
+ T* tmp_log_probs,
71
+ int * topk_tmp_id_buf,
72
+ T* topk_tmp_val_buf,
73
+ const int k,
74
+ const int vocab_size) {
75
+ const bool IS_FP16 = std::is_same<T, half>::value;
76
+ const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
77
+ typedef cub::BlockReduce<TopK_2<T>, THREADBLOCK_SIZE> BlockReduce;
78
+ __shared__ typename BlockReduce::TempStorage temp_storage;
79
+
80
+ const int tid = threadIdx .x ;
81
+ const int bid = blockIdx .x ;
82
+ TopK_2<T> partial;
83
+
84
+ for (int elem_id = tid; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) {
85
+ int index = elem_id + bid * vocab_size;
86
+ tmp_log_probs[index] = log_probs[index];
87
+ }
88
+
89
+ for (int ite = 0 ; ite < k; ite++) {
90
+ partial.init ();
91
+ #pragma unroll
92
+ for (int elem_id = tid; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) {
93
+ int index = elem_id + bid * vocab_size;
94
+ partial.insert (tmp_log_probs[index], index);
95
+ }
96
+
97
+ TopK_2<T> total =
98
+ BlockReduce (temp_storage).Reduce (partial, reduce_topk_op_2<T>);
99
+
100
+ if (tid == 0 ) {
101
+ const int index = bid * k + ite;
102
+ topk_tmp_id_buf[index] = total.p ;
103
+ topk_tmp_val_buf[index] = total.u ;
104
+ tmp_log_probs[total.p ] = -MAX_T_VAL;
105
+ }
106
+ __syncthreads ();
107
+ }
108
+ }
109
+
68
110
#define CASE_K (K ) \
69
111
case K: \
70
112
beam_topK_kernel<T, K, block_size><<<batch_size, block_size, 0 , stream>>> ( \
@@ -554,7 +596,6 @@ void topK_sampling_kernel_kernelLauncher(void* workspace,
554
596
cudaStream_t stream) {
555
597
std::minstd_rand engine;
556
598
int seed = std::random_device ()();
557
-
558
599
const int batch_size = args.batch_size_ ;
559
600
const int vocab_size = args.vocab_size_ ;
560
601
const int candidate_num = args.candidate_num_ ;
@@ -563,25 +604,36 @@ void topK_sampling_kernel_kernelLauncher(void* workspace,
563
604
564
605
int topk_tmp_ids_buf_size =
565
606
args.batch_size_ * args.candidate_num_ ; // type int
607
+ int temp_log_probs_buf_size =
608
+ args.batch_size_ * args.candidate_num_ * vocab_size;
566
609
int topk_tmp_val_buf_size = args.batch_size_ * args.candidate_num_ ; // type T
610
+
611
+ temp_log_probs_buf_size = (int )(ceil (temp_log_probs_buf_size / 4 .)) * 4 ;
567
612
topk_tmp_ids_buf_size = (int )(ceil (topk_tmp_ids_buf_size / 4 .)) * 4 ;
568
613
topk_tmp_val_buf_size = (int )(ceil (topk_tmp_val_buf_size / 4 .)) * 4 ;
569
614
570
615
if (workspace == nullptr ) {
571
- workspace_size = sizeof (int ) * topk_tmp_ids_buf_size +
572
- sizeof (int ) * topk_tmp_val_buf_size;
616
+ workspace_size = sizeof (T) * temp_log_probs_buf_size +
617
+ sizeof (int ) * topk_tmp_ids_buf_size +
618
+ sizeof (T) * topk_tmp_val_buf_size;
573
619
} else {
574
- int * topk_tmp_id_buf = (int *)workspace;
620
+ T* temp_log_probs = (T*)workspace;
621
+ int * topk_tmp_id_buf = (int *)(temp_log_probs + temp_log_probs_buf_size);
575
622
T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);
576
623
577
624
switch (candidate_num) {
578
625
CASE_K (1 );
579
626
CASE_K (2 );
580
627
CASE_K (4 );
581
628
default :
582
- printf (" [ERROR] Topk kernel does not support candidate_num = %d \n " ,
583
- candidate_num);
584
- exit (0 );
629
+ beam_topK_kernel_general<
630
+ T,
631
+ block_size><<<batch_size, block_size, 0 , stream>>> (log_probs,
632
+ temp_log_probs,
633
+ topk_tmp_id_buf,
634
+ topk_tmp_val_buf,
635
+ candidate_num,
636
+ vocab_size);
585
637
break ;
586
638
}
587
639
sampling<T><<<batch_size, candidate_num, 0 , stream>>> (topk_tmp_id_buf,
0 commit comments