@@ -266,8 +266,10 @@ def parallel_nsa_compression_bwd_kernel_dkv(
266266
267267 p_k = tl .make_block_ptr (k + (tl .cdiv (bos , BS ) * H + i_h ) * K , (TC , K ), (H * K , 1 ), (i_c * BC , 0 ), (BC , BK ), (1 , 0 ))
268268 p_v = tl .make_block_ptr (v + (tl .cdiv (bos , BS ) * H + i_h ) * V , (TC , V ), (H * V , 1 ), (i_c * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
269- p_dk = tl .make_block_ptr (dk + (i_v * B * T * H + tl .cdiv (bos , BS ) * H + i_h ) * K , (TC , K ), (H * K , 1 ), (i_c * BC , 0 ), (BC , BK ), (1 , 0 ))
270- p_dv = tl .make_block_ptr (dv + (i_v * B * T * H + tl .cdiv (bos , BS ) * H + i_h ) * V , (TC , V ), (H * V , 1 ), (i_c * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
269+ p_dk = tl .make_block_ptr (dk + (i_v * B * T * H + tl .cdiv (bos , BS ) * H + i_h ) * K ,
270+ (TC , K ), (H * K , 1 ), (i_c * BC , 0 ), (BC , BK ), (1 , 0 ))
271+ p_dv = tl .make_block_ptr (dv + (i_v * B * T * H + tl .cdiv (bos , BS ) * H + i_h ) * V ,
272+ (TC , V ), (H * V , 1 ), (i_c * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
271273
272274 # [BC, BK]
273275 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
@@ -310,6 +312,7 @@ def parallel_nsa_compression_bwd_kernel_dkv(
310312
311313
312314@triton .heuristics ({
315+ 'BC2' : lambda args : args ['BC' ] // 2 ,
313316 'USE_OFFSETS' : lambda args : args ['offsets' ] is not None
314317})
315318@triton .autotune (
@@ -335,6 +338,7 @@ def parallel_nsa_kernel_topk(
335338 K : tl .constexpr ,
336339 S : tl .constexpr ,
337340 BC : tl .constexpr ,
341+ BC2 : tl .constexpr ,
338342 BS : tl .constexpr ,
339343 BK : tl .constexpr ,
340344 USE_OFFSETS : tl .constexpr ,
@@ -402,7 +406,7 @@ def parallel_nsa_kernel_topk(
402406 # [BC]
403407 b_i = tl .full ([BC ], - 1 , dtype = tl .float32 )
404408 o_i = tl .zeros ([BC ], dtype = tl .int32 )
405- m_i = tl .arange (0 , BC ) < S
409+ m_i = tl .arange (0 , BC ) < BC2
406410 for i_c in range (0 , i_t // BS + 1 , BC ):
407411 o_c = i_c + tl .arange (0 , BC )
408412
@@ -432,8 +436,7 @@ def parallel_nsa_kernel_topk(
432436 b_i , o_i = _bitonic_merge (b_i , o_i .to (tl .int32 ), n_dims , True , n_dims )
433437
434438 m_top = tl .arange (0 , 2 ) == 0
435-
436- b_top = tl .sum (m_top [:, None ] * tl .reshape (o_i - 1 , [2 , S ]), 0 )
439+ b_top = tl .sum (m_top [:, None ] * tl .reshape (o_i - 1 , [2 , BC2 ]), 0 )
437440
438441 p_b = tl .make_block_ptr (block_indices + (bos + i_t ) * H * S , (H * S ,), (1 ,), (i_h * S ,), (S ,), (0 ,))
439442 tl .store (p_b , b_top .to (p_b .dtype .element_ty ))
@@ -1369,6 +1372,7 @@ def parallel_nsa_bwd(
13691372 dk = dk .sum (0 )
13701373 return dq , dk , dv
13711374
1375+
13721376@torch .compile
13731377class ParallelNSAFunction (torch .autograd .Function ):
13741378
@@ -1609,7 +1613,6 @@ def parallel_nsa_with_compression(
16091613 offsets = cu_seqlens
16101614 )
16111615 o_slc , o_swa = ParallelNSAFunction .apply (q , k , v , block_indices , block_counts , block_size , window_size , scale , cu_seqlens )
1612- pos = torch .arange (0 , q .shape [1 ]).view (1 , q .shape [1 ], 1 ).to (q .device ).to (block_indices .dtype )
16131616 o = o_slc * g_slc .unsqueeze (- 1 )
16141617 if o_cmp is not None :
16151618 o = torch .addcmul (o , o_cmp , g_cmp .unsqueeze (- 1 ))
0 commit comments