@@ -330,7 +330,6 @@ def parallel_nsa_compression_bwd_kernel_dkv(
330330
331331
332332@triton .heuristics ({
333- 'BC2' : lambda args : args ['BC' ] // 2 ,
334333 'USE_OFFSETS' : lambda args : args ['offsets' ] is not None
335334})
336335@triton .autotune (
@@ -358,7 +357,6 @@ def parallel_nsa_kernel_topk(
358357 K : tl .constexpr ,
359358 S : tl .constexpr ,
360359 BC : tl .constexpr ,
361- BC2 : tl .constexpr ,
362360 BS : tl .constexpr ,
363361 BK : tl .constexpr ,
364362 USE_OFFSETS : tl .constexpr ,
@@ -428,7 +426,7 @@ def parallel_nsa_kernel_topk(
428426 # [BC]
429427 b_i = tl .full ([BC ], - 1 , dtype = tl .float32 )
430428 o_i = tl .zeros ([BC ], dtype = tl .int32 )
431- m_i = tl .arange (0 , BC ) < BC2
429+ m_i = tl .arange (0 , BC ) < BC // 2
432430 for i_c in range (0 , i_t // BS + 1 , BC ):
433431 o_c = i_c + tl .arange (0 , BC )
434432
@@ -457,8 +455,8 @@ def parallel_nsa_kernel_topk(
457455 else :
458456 b_i , o_i = _bitonic_merge (b_i , o_i .to (tl .int32 ), n_dims , True , n_dims )
459457
460- m_top = tl .arange (0 , 2 ) == 0
461- b_top = tl .sum (m_top [:, None ] * tl .reshape (o_i - 1 , [2 , BC2 ]), 0 )
458+ m_top = tl .arange (0 , BC // S ) == 0
459+ b_top = tl .sum (m_top [:, None ] * tl .reshape (o_i - 1 , [BC // S , S ]), 0 )
462460
463461 p_b = tl .make_block_ptr (block_indices + (bos + i_t ) * H * S , (H * S ,), (1 ,), (i_h * S ,), (S ,), (0 ,))
464462 tl .store (p_b , b_top .to (p_b .dtype .element_ty ))
0 commit comments