@@ -1539,6 +1539,7 @@ struct RenormTempStorage {
1539
1539
struct {
1540
1540
float max_val;
1541
1541
float min_val;
1542
+ float row_sum;
1542
1543
union {
1543
1544
struct {
1544
1545
float values[2 ];
@@ -1565,9 +1566,61 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
1565
1566
uint8_t smem_renorm[];
1566
1567
auto & temp_storage =
1567
1568
reinterpret_cast <RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
1568
- temp_storage.max_val = 0 ;
1569
1569
vec_t <float , VEC_SIZE> probs_vec;
1570
1570
1571
+ // Fast-path: when p >= 1.0 (e.g., p == 1.0), perform simple sum and normalization
1572
+ if (p >= 1 .0f ) {
1573
+ // Stage A: per-thread float accumulation over assigned lanes (vectorized)
1574
+ float thread_sum = 0 .0f ;
1575
+ const uint32_t num_iters = ceil_div (d, BLOCK_THREADS * VEC_SIZE);
1576
+ for (uint32_t i = 0 ; i < num_iters; ++i) {
1577
+ probs_vec.fill (0 .0f );
1578
+ const uint32_t base_idx = (i * BLOCK_THREADS + tx) * VEC_SIZE;
1579
+ if (base_idx < d) {
1580
+ probs_vec.cast_load (probs + row_idx * d + base_idx);
1581
+ }
1582
+ #pragma unroll
1583
+ for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
1584
+ const uint32_t idx = base_idx + j;
1585
+ if (idx < d) thread_sum += probs_vec[j];
1586
+ }
1587
+ }
1588
+
1589
+ // Block reduce (float)
1590
+ float row_sum =
1591
+ BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1592
+ .Sum (thread_sum);
1593
+ // Broadcast via shared
1594
+ if (tx == 0 ) temp_storage.row_sum = row_sum;
1595
+ __syncthreads ();
1596
+ row_sum = temp_storage.row_sum ;
1597
+
1598
+ // Guard against zero sum
1599
+ const float denom = (row_sum <= 1e-8f ) ? 1 .0f : row_sum;
1600
+ const float normalizer = math::ptx_rcp (denom);
1601
+
1602
+ // Stage B: normalize and store
1603
+ for (uint32_t i = 0 ; i < num_iters; ++i) {
1604
+ probs_vec.fill (0 .0f );
1605
+ const uint32_t base_idx = (i * BLOCK_THREADS + tx) * VEC_SIZE;
1606
+ if (base_idx < d) {
1607
+ probs_vec.cast_load (probs + row_idx * d + base_idx);
1608
+ }
1609
+ #pragma unroll
1610
+ for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
1611
+ const uint32_t idx = base_idx + j;
1612
+ float v = probs_vec[j];
1613
+ probs_vec[j] = (idx < d) ? (v * normalizer) : 0 .0f ;
1614
+ }
1615
+ if (base_idx < d) {
1616
+ probs_vec.cast_store (renormed_prob + row_idx * d + base_idx);
1617
+ }
1618
+ }
1619
+ return ; // Exit after fast-path processing
1620
+ }
1621
+
1622
+ // Original Top-P renormalization logic
1623
+ temp_storage.max_val = 0 ;
1571
1624
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
1572
1625
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(probs, row_idx, d,
1573
1626
temp_storage);
0 commit comments