1212import triton .language .core as core
1313from einops import rearrange
1414
15- from fla .ops .common .utils import (prepare_chunk_indices , prepare_lens ,
16- prepare_token_indices )
15+ from fla .ops .common .utils import (prepare_chunk_indices , prepare_chunk_offsets ,
16+ prepare_lens , prepare_token_indices )
1717from fla .utils import autocast_custom_bwd , autocast_custom_fwd , contiguous
1818from native_sparse_attention .ops .utils import _bitonic_merge
1919
@@ -48,6 +48,7 @@ def parallel_nsa_compression_fwd_kernel(
4848 scale ,
4949 offsets ,
5050 token_indices ,
51+ chunk_offsets ,
5152 T ,
5253 H : tl .constexpr ,
5354 HQ : tl .constexpr ,
@@ -67,8 +68,10 @@ def parallel_nsa_compression_fwd_kernel(
6768 i_n , i_t = tl .load (token_indices + i_t * 2 ).to (tl .int32 ), tl .load (token_indices + i_t * 2 + 1 ).to (tl .int32 )
6869 bos , eos = tl .load (offsets + i_n ).to (tl .int32 ), tl .load (offsets + i_n + 1 ).to (tl .int32 )
6970 T = eos - bos
71+ boc = tl .load (chunk_offsets + i_n ).to (tl .int32 )
7072 else :
7173 bos , eos = i_b * T , i_b * T + T
74+ boc = i_b * tl .cdiv (T , BS )
7275
7376 p_q = tl .make_block_ptr (q + (bos + i_t ) * HQ * K , (HQ , K ), (K , 1 ), (i_h * G , 0 ), (G , BK ), (1 , 0 ))
7477
@@ -94,8 +97,8 @@ def parallel_nsa_compression_fwd_kernel(
9497 for i_c in range (0 , NC , BC ):
9598 o_c = i_c + tl .arange (0 , BC )
9699
97- p_k = tl .make_block_ptr (k + (tl . cdiv ( bos , BS ) * H + i_h ) * K , (K , TC ), (1 , H * K ), (0 , i_c ), (BK , BC ), (0 , 1 ))
98- p_v = tl .make_block_ptr (v + (tl . cdiv ( bos , BS ) * H + i_h ) * V , (TC , V ), (H * V , 1 ), (i_c , i_v * BV ), (BC , BV ), (1 , 0 ))
100+ p_k = tl .make_block_ptr (k + (boc * H + i_h ) * K , (K , TC ), (1 , H * K ), (0 , i_c ), (BK , BC ), (0 , 1 ))
101+ p_v = tl .make_block_ptr (v + (boc * H + i_h ) * V , (TC , V ), (H * V , 1 ), (i_c , i_v * BV ), (BC , BV ), (1 , 0 ))
99102 # [BK, BC]
100103 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
101104 # [BC, BV]
@@ -150,6 +153,7 @@ def parallel_nsa_compression_bwd_kernel_dq(
150153 scale ,
151154 offsets ,
152155 token_indices ,
156+ chunk_offsets ,
153157 T ,
154158 B : tl .constexpr ,
155159 H : tl .constexpr ,
@@ -170,8 +174,10 @@ def parallel_nsa_compression_bwd_kernel_dq(
170174 i_n , i_t = tl .load (token_indices + i_t * 2 ).to (tl .int32 ), tl .load (token_indices + i_t * 2 + 1 ).to (tl .int32 )
171175 bos , eos = tl .load (offsets + i_n ).to (tl .int32 ), tl .load (offsets + i_n + 1 ).to (tl .int32 )
172176 T = eos - bos
177+ boc = tl .load (chunk_offsets + i_n ).to (tl .int32 )
173178 else :
174179 bos , eos = i_b * T , i_b * T + T
180+ boc = i_b * tl .cdiv (T , BS )
175181
176182 q += (bos + i_t ) * HQ * K
177183 do += (bos + i_t ) * HQ * V
@@ -206,8 +212,8 @@ def parallel_nsa_compression_bwd_kernel_dq(
206212 b_dq = tl .zeros ([G , BK ], dtype = tl .float32 )
207213 for i_c in range (0 , NC , BC ):
208214 o_c = i_c + tl .arange (0 , BC )
209- p_k = tl .make_block_ptr (k + (tl . cdiv ( bos , BS ) * H + i_h ) * K , (K , TC ), (1 , H * K ), (0 , i_c ), (BK , BC ), (0 , 1 ))
210- p_v = tl .make_block_ptr (v + (tl . cdiv ( bos , BS ) * H + i_h ) * V , (V , TC ), (1 , H * V ), (i_v * BV , i_c ), (BV , BC ), (0 , 1 ))
215+ p_k = tl .make_block_ptr (k + (boc * H + i_h ) * K , (K , TC ), (1 , H * K ), (0 , i_c ), (BK , BC ), (0 , 1 ))
216+ p_v = tl .make_block_ptr (v + (boc * H + i_h ) * V , (V , TC ), (1 , H * V ), (i_v * BV , i_c ), (BV , BC ), (0 , 1 ))
211217 # [BK, BC]
212218 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
213219 # [BV, BC]
@@ -249,6 +255,7 @@ def parallel_nsa_compression_bwd_kernel_dkv(
249255 dv ,
250256 offsets ,
251257 chunk_indices ,
258+ chunk_offsets ,
252259 scale ,
253260 T ,
254261 B : tl .constexpr ,
@@ -270,18 +277,18 @@ def parallel_nsa_compression_bwd_kernel_dkv(
270277 i_n , i_c = tl .load (chunk_indices + i_c * 2 ).to (tl .int32 ), tl .load (chunk_indices + i_c * 2 + 1 ).to (tl .int32 )
271278 bos , eos = tl .load (offsets + i_n ).to (tl .int32 ), tl .load (offsets + i_n + 1 ).to (tl .int32 )
272279 T = eos - bos
280+ boc = tl .load (chunk_offsets + i_n ).to (tl .int32 )
273281 else :
274282 bos , eos = i_b * T , i_b * T + T
283+ boc = i_b * tl .cdiv (T , BS )
275284
276285 # the number of compression representations in total
277286 TC = tl .cdiv (T , BS )
278287
279- p_k = tl .make_block_ptr (k + (tl .cdiv (bos , BS ) * H + i_h ) * K , (TC , K ), (H * K , 1 ), (i_c * BC , 0 ), (BC , BK ), (1 , 0 ))
280- p_v = tl .make_block_ptr (v + (tl .cdiv (bos , BS ) * H + i_h ) * V , (TC , V ), (H * V , 1 ), (i_c * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
281- p_dk = tl .make_block_ptr (dk + (i_v * B * T * H + tl .cdiv (bos , BS ) * H + i_h ) * K ,
282- (TC , K ), (H * K , 1 ), (i_c * BC , 0 ), (BC , BK ), (1 , 0 ))
283- p_dv = tl .make_block_ptr (dv + (i_v * B * T * H + tl .cdiv (bos , BS ) * H + i_h ) * V ,
284- (TC , V ), (H * V , 1 ), (i_c * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
288+ p_k = tl .make_block_ptr (k + (boc * H + i_h ) * K , (TC , K ), (H * K , 1 ), (i_c * BC , 0 ), (BC , BK ), (1 , 0 ))
289+ p_v = tl .make_block_ptr (v + (boc * H + i_h ) * V , (TC , V ), (H * V , 1 ), (i_c * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
290+ p_dk = tl .make_block_ptr (dk + (i_v * B * T * H + boc * H + i_h ) * K , (TC , K ), (H * K , 1 ), (i_c * BC , 0 ), (BC , BK ), (1 , 0 ))
291+ p_dv = tl .make_block_ptr (dv + (i_v * B * T * H + boc * H + i_h ) * V , (TC , V ), (H * V , 1 ), (i_c * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
285292
286293 # [BC, BK]
287294 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
@@ -344,6 +351,7 @@ def parallel_nsa_kernel_topk(
344351 block_indices ,
345352 offsets ,
346353 token_indices ,
354+ chunk_offsets ,
347355 T ,
348356 H : tl .constexpr ,
349357 HQ : tl .constexpr ,
@@ -363,8 +371,10 @@ def parallel_nsa_kernel_topk(
363371 i_n , i_t = tl .load (token_indices + i_t * 2 ).to (tl .int32 ), tl .load (token_indices + i_t * 2 + 1 ).to (tl .int32 )
364372 bos , eos = tl .load (offsets + i_n ).to (tl .int32 ), tl .load (offsets + i_n + 1 ).to (tl .int32 )
365373 T = eos - bos
374+ boc = tl .load (chunk_offsets + i_n ).to (tl .int32 )
366375 else :
367376 bos , eos = i_b * T , i_b * T + T
377+ boc = i_b * tl .cdiv (T , BS )
368378
369379 p_q = tl .make_block_ptr (q + (bos + i_t ) * HQ * K , (HQ , K ), (K , 1 ), (i_h * G , 0 ), (G , BK ), (1 , 0 ))
370380
@@ -382,7 +392,7 @@ def parallel_nsa_kernel_topk(
382392 # 1. lse computation
383393 ################################
384394 if lse is not None :
385- b_lse = tl .load (lse + (tl . cdiv ( bos , BS ) + i_t ) * HQ + i_h * G + tl .arange (0 , G ))
395+ b_lse = tl .load (lse + (bos + i_t ) * HQ + i_h * G + tl .arange (0 , G ))
386396 else :
387397 # max scores for the current block
388398 b_m = tl .full ([G ], float ('-inf' ), dtype = tl .float32 )
@@ -391,7 +401,7 @@ def parallel_nsa_kernel_topk(
391401 for i_c in range (0 , NC , BC ):
392402 o_c = i_c + tl .arange (0 , BC )
393403
394- p_k = tl .make_block_ptr (k + (tl . cdiv ( bos , BS ) * H + i_h ) * K , (K , TC ), (1 , H * K ), (0 , i_c ), (BK , BC ), (0 , 1 ))
404+ p_k = tl .make_block_ptr (k + (boc * H + i_h ) * K , (K , TC ), (1 , H * K ), (0 , i_c ), (BK , BC ), (0 , 1 ))
395405 # [BK, BC]
396406 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
397407
@@ -423,7 +433,7 @@ def parallel_nsa_kernel_topk(
423433 for i_c in range (0 , i_t // BS + 1 , BC ):
424434 o_c = i_c + tl .arange (0 , BC )
425435
426- p_k = tl .make_block_ptr (k + (tl . cdiv ( bos , BS ) * H + i_h ) * K , (K , TC ), (1 , H * K ), (0 , i_c ), (BK , BC ), (0 , 1 ))
436+ p_k = tl .make_block_ptr (k + (boc * H + i_h ) * K , (K , TC ), (1 , H * K ), (0 , i_c ), (BK , BC ), (0 , 1 ))
427437 # [BK, BC]
428438 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
429439 # [G, BC]
@@ -842,6 +852,8 @@ def parallel_nsa_compression_fwd(
842852 NV = triton .cdiv (V , BV )
843853 assert NK == 1 , "The key dimension can not be larger than 256"
844854
855+ chunk_offsets = prepare_chunk_offsets (offsets , BS ) if offsets is not None else None
856+
845857 grid = (T , NV , B * H )
846858 o = torch .empty (B , T , HQ , V , dtype = v .dtype , device = q .device )
847859 lse = torch .empty (B , T , HQ , dtype = torch .float , device = q .device )
@@ -855,6 +867,7 @@ def parallel_nsa_compression_fwd(
855867 scale = scale ,
856868 offsets = offsets ,
857869 token_indices = token_indices ,
870+ chunk_offsets = chunk_offsets ,
858871 T = T ,
859872 H = H ,
860873 HQ = HQ ,
@@ -888,6 +901,15 @@ def parallel_nsa_compression_bwd(
888901 BK = triton .next_power_of_2 (K )
889902 BV = min (128 , triton .next_power_of_2 (v .shape [- 1 ]))
890903 NV = triton .cdiv (V , BV )
904+ if offsets is not None :
905+ lens = prepare_lens (offsets )
906+ chunk_indices = torch .cat ([torch .arange (n ) for n in triton .cdiv (triton .cdiv (lens , BS ), BC ).tolist ()])
907+ chunk_indices = torch .stack ([chunk_indices .eq (0 ).cumsum (0 ) - 1 , chunk_indices ], 1 ).to (offsets )
908+ chunk_offsets = prepare_chunk_offsets (offsets , BS )
909+ NC = len (chunk_indices )
910+ else :
911+ chunk_indices , chunk_offsets = None , None
912+ NC = triton .cdiv (triton .cdiv (T , BS ), BC )
891913
892914 delta = parallel_nsa_bwd_preprocess (o , do )
893915
@@ -904,6 +926,7 @@ def parallel_nsa_compression_bwd(
904926 scale = scale ,
905927 offsets = offsets ,
906928 token_indices = token_indices ,
929+ chunk_offsets = chunk_offsets ,
907930 T = T ,
908931 B = B ,
909932 H = H ,
@@ -918,15 +941,6 @@ def parallel_nsa_compression_bwd(
918941 )
919942 dq = dq .sum (0 )
920943
921- if offsets is not None :
922- lens = prepare_lens (offsets )
923- chunk_indices = torch .cat ([torch .arange (n ) for n in triton .cdiv (triton .cdiv (lens , BS ), BC ).tolist ()])
924- chunk_indices = torch .stack ([chunk_indices .eq (0 ).cumsum (0 ) - 1 , chunk_indices ], 1 ).to (offsets )
925- NC = len (chunk_indices )
926- else :
927- chunk_indices = None
928- NC = triton .cdiv (triton .cdiv (T , BS ), BC )
929-
930944 dk = torch .empty (NV , * k .shape , dtype = k .dtype if NV == 1 else torch .float , device = q .device )
931945 dv = torch .empty (v .shape , dtype = v .dtype , device = q .device )
932946
@@ -942,6 +956,7 @@ def parallel_nsa_compression_bwd(
942956 dv = dv ,
943957 offsets = offsets ,
944958 chunk_indices = chunk_indices ,
959+ chunk_offsets = chunk_offsets ,
945960 scale = scale ,
946961 T = T ,
947962 B = B ,
@@ -1037,6 +1052,7 @@ def parallel_nsa_topk(
10371052
10381053 block_indices = torch .zeros (B , T , H , S , dtype = torch .int32 , device = q .device )
10391054 token_indices = prepare_token_indices (offsets ) if offsets is not None else None
1055+ chunk_offsets = prepare_chunk_offsets (offsets , BS ) if offsets is not None else None
10401056 grid = (T , B * H )
10411057 parallel_nsa_kernel_topk [grid ](
10421058 q = q ,
@@ -1046,6 +1062,7 @@ def parallel_nsa_topk(
10461062 block_indices = block_indices ,
10471063 offsets = offsets ,
10481064 token_indices = token_indices ,
1065+ chunk_offsets = chunk_offsets ,
10491066 T = T ,
10501067 H = H ,
10511068 HQ = HQ ,
0 commit comments