@@ -362,6 +362,7 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind
362362 float aggregate (0 );
363363 float u = curand_uniform (&state);
364364
365+ #pragma unroll 2
365366 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
366367 probs_vec.fill (0 );
367368 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -405,14 +406,10 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
405406 reinterpret_cast <SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
406407 smem_sampling);
407408
408- float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
409- SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
410- probs, row_idx, d, temp_storage);
411-
412409 vec_t <float , VEC_SIZE> probs_vec;
413410 float aggregate;
414411 float q = 1 ;
415- double low = 0 , high = max_val ;
412+ double low = 0 , high = 1 . f ;
416413 int sampled_id;
417414 int round = 0 ;
418415 do {
@@ -421,6 +418,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
421418 __syncthreads ();
422419 float u = curand_uniform (&state) * q;
423420 aggregate = 0 ;
421+ #pragma unroll 2
424422 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
425423 probs_vec.fill (0 );
426424 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -446,6 +444,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
446444 double pivot_1 = (pivot_0 + high) / 2 ;
447445
448446 ValueCount<float > aggregate_gt_pivot_0{0 , 0 }, aggregate_gt_pivot_1{0 , 0 };
447+ #pragma unroll 2
449448 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
450449 probs_vec.fill (0 );
451450 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -522,20 +521,17 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
522521 reinterpret_cast <SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
523522 smem_sampling);
524523
525- float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
526- SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
527- probs, row_idx, d, temp_storage);
528-
529524 vec_t <float , VEC_SIZE> probs_vec;
530525 float aggregate;
531526 float q = 1 ;
532- double low = 0 , high = max_val ;
527+ double low = 0 , high = 1 . f ;
533528 int sampled_id;
534529 do {
535530 temp_storage.sampled_id = d;
536531 __syncthreads ();
537532 float u = curand_uniform (&state) * q;
538533 aggregate = 0 ;
534+ #pragma unroll 2
539535 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
540536 probs_vec.fill (0 );
541537 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -561,6 +557,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
561557 double pivot_1 = (pivot_0 + high) / 2 ;
562558
563559 float aggregate_gt_pivot_0 = 0 , aggregate_gt_pivot_1 = 0 ;
560+ #pragma unroll 2
564561 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
565562 probs_vec.fill (0 );
566563 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -637,6 +634,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
637634
638635 vec_t <float , VEC_SIZE> probs_vec;
639636 float aggregate_gt_pivot = 0 ;
637+ #pragma unroll 2
640638 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
641639 probs_vec.fill (0 );
642640 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -664,6 +662,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
664662 temp_storage.sampled_id = d;
665663 __syncthreads ();
666664 float u = curand_uniform (&state) * q;
665+ #pragma unroll 2
667666 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
668667 probs_vec.fill (0 );
669668 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -709,20 +708,17 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
709708 reinterpret_cast <SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
710709 smem_sampling);
711710
712- float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
713- SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
714- probs, row_idx, d, temp_storage);
715-
716711 vec_t <float , VEC_SIZE> probs_vec;
717712 float aggregate;
718713 float q = 1 ;
719- double low = 0 , high = max_val ;
714+ double low = 0 , high = 1 . f ;
720715 int sampled_id;
721716 do {
722717 temp_storage.sampled_id = d;
723718 __syncthreads ();
724719 float u = curand_uniform (&state) * q;
725720 aggregate = 0 ;
721+ #pragma unroll 2
726722 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
727723 probs_vec.fill (0 );
728724 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -748,6 +744,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
748744 double pivot_1 = (pivot_0 + high) / 2 ;
749745
750746 ValueCount<float > aggregate_gt_pivot_0{0 , 0 }, aggregate_gt_pivot_1{0 , 0 };
747+ #pragma unroll 2
751748 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
752749 probs_vec.fill (0 );
753750 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -988,6 +985,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
988985 double mid = (low + high) / 2 ;
989986 min_gt_low = high;
990987 max_le_high = low;
988+ #pragma unroll 2
991989 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
992990 probs_vec.fill (0 );
993991 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1034,6 +1032,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
10341032 float normalizer = math::ptx_rcp (max (sum_low, 1e-8 ));
10351033
10361034 // normalize
1035+ #pragma unroll 2
10371036 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
10381037 probs_vec.fill (0 );
10391038 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1085,6 +1084,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
10851084 double mid = (low + high) / 2 ;
10861085 min_gt_low = high;
10871086 max_le_high = low;
1087+ #pragma unroll 2
10881088 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
10891089 logits_vec.fill (0 );
10901090 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1132,6 +1132,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
11321132 }
11331133
11341134 // masking
1135+ #pragma unroll 2
11351136 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
11361137 logits_vec.fill (0 );
11371138 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1185,6 +1186,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
11851186 double mid = (low + high) / 2 ;
11861187 min_gt_low = high;
11871188 max_le_high = low;
1189+ #pragma unroll 2
11881190 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
11891191 probs_vec.fill (0 );
11901192 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1236,6 +1238,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
12361238 }
12371239
12381240 // normalize
1241+ #pragma unroll 2
12391242 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
12401243 probs_vec.fill (0 );
12411244 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1372,6 +1375,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
13721375 float sum_relu_q_minus_p = 0 ;
13731376 vec_t <float , VEC_SIZE> q_vec, p_vec;
13741377 float relu_q_minus_p[VEC_SIZE];
1378+ #pragma unroll 2
13751379 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
13761380 q_vec.fill (0 );
13771381 p_vec.fill (0 );
@@ -1403,6 +1407,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
14031407 float u = curand_uniform (&curand_state) * sum_relu_q_minus_p;
14041408
14051409 float aggregate_relu_q_minus_p (0 );
1410+ #pragma unroll 2
14061411 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
14071412 q_vec.fill (0 );
14081413 p_vec.fill (0 );
0 commit comments