@@ -420,7 +420,8 @@ __device__ void filter_and_histogram(T const* in_buf,
420420 IdxT* histogram,
421421 bool select_min,
422422 int pass,
423- bool early_stop)
423+ bool early_stop,
424+ IdxT k)
424425{
425426 constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
426427 __shared__ IdxT histogram_smem[num_buckets];
@@ -893,9 +894,19 @@ __global__ void radix_kernel(T const* in,
893894 int const pass)
894895{
895896 const int64_t batch_id = blockIdx.y;
896- const IdxT row_len = phase == Phase::Prefill
897- ? rowEnds[batch_id] - rowStarts[batch_id]
898- : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1 ;
897+
898+ IdxT row_len = len;
899+ if (phase == Phase::Prefill)
900+ {
901+ if (rowStarts && rowEnds)
902+ {
903+ row_len = rowEnds[batch_id] - rowStarts[batch_id];
904+ }
905+ }
906+ else
907+ {
908+ row_len = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
909+ }
899910
900911 auto counter = counters + batch_id;
901912 IdxT current_k;
@@ -965,7 +976,8 @@ __global__ void radix_kernel(T const* in,
965976 histogram,
966977 select_min,
967978 pass,
968- early_stop);
979+ early_stop,
980+ k);
969981 __threadfence();
970982
971983 bool isLastBlock = false;
@@ -1187,7 +1199,8 @@ __device__ bool filter_and_histogram_for_one_block(T const* in_buf,
11871199 Counter<T, IdxT>* counter,
11881200 IdxT* histogram,
11891201 bool select_min,
1190- int pass)
1202+ int pass,
1203+ IdxT k)
11911204{
11921205 constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
11931206 for (int i = threadIdx.x; i < num_buckets * 2; i += blockDim.x)
@@ -1371,11 +1384,25 @@ __global__ void radix_topk_one_block_kernel(T const* in,
13711384 __shared__ IdxT histogram[num_buckets * 2];
13721385
13731386 const int64_t batch_id = blockIdx.x;
1374- const IdxT rowStart = phase == Phase::Prefill ? rowStarts[batch_id] : 0 ;
1375- const IdxT rowEnd = phase == Phase::Prefill
1376- ? rowEnds[batch_id]
1377- : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1 ;
1378- const IdxT row_len = rowEnd - rowStart;
1387+
1388+ IdxT rowStart = 0;
1389+ IdxT rowEnd = len;
1390+ if (phase == Phase::Prefill)
1391+ {
1392+ if (rowStarts && rowEnds)
1393+ {
1394+ rowStart = rowStarts[batch_id];
1395+ rowEnd = rowEnds[batch_id];
1396+ }
1397+ }
1398+ else
1399+ {
1400+ rowEnd = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
1401+ rowStart = 0;
1402+ }
1403+
1404+ const IdxT row_len = rowEnd - rowStart;
1405+
13791406 if (threadIdx.x == 0)
13801407 {
13811408 counter.k = k;
@@ -1448,7 +1475,8 @@ __global__ void radix_topk_one_block_kernel(T const* in,
14481475 &counter,
14491476 histogram,
14501477 select_min,
1451- pass); // @TODO CHECK UPDATE CODE
1478+ pass,
1479+ k); // @TODO CHECK UPDATE CODE
14521480 __syncthreads();
14531481
14541482 scan<IdxT, BitsPerPass, BlockSize>(histogram + use_one_pass * num_buckets);
@@ -1811,6 +1839,35 @@ void standalone_stable_radix_11bits(void* buf,
18111839 }
18121840}
18131841
1842+ // Explicit template instantiation for standalone_stable_radix_11bits
1843+ template void standalone_stable_radix_11bits<float, int, true, true>(void* buf,
1844+ size_t & buf_size,
1845+ float const * in,
1846+ int batch_size,
1847+ int64_t len,
1848+ int * rowStarts,
1849+ int * rowEnds,
1850+ int k,
1851+ float * out,
1852+ int * out_idx,
1853+ bool greater,
1854+ hipStream_t stream,
1855+ int next_n);
1856+
1857+ template void standalone_stable_radix_11bits<float, int, false, true>(void* buf,
1858+ size_t & buf_size,
1859+ float const * in,
1860+ int batch_size,
1861+ int64_t len,
1862+ int * rowStarts,
1863+ int * rowEnds,
1864+ int k,
1865+ float * out,
1866+ int * out_idx,
1867+ bool greater,
1868+ hipStream_t stream,
1869+ int next_n);
1870+
18141871// AIR TopK end
18151872
18161873static inline __device__ uint32_t floatAsSortableUint (float x)
@@ -2410,6 +2467,9 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
24102467 return buf_size;
24112468}
24122469
2470+ // Explicit template instantiation to ensure the symbol is available for linking
2471+ template int64_t invokeComputeTopkLastDimWorkspaceSize<float>(int32_t numRows, int32_t stride0);
2472+
24132473void top_k_per_row_prefill (const torch::Tensor& logits,
24142474 const torch::Tensor& rowStarts,
24152475 const torch::Tensor& rowEnds,
0 commit comments