@@ -590,6 +590,7 @@ def backward_kernel_one_col_block_sparse(
     V,
     kv_block_indices,
     kv_block_mask,
+    kv_block_grads,
     DO,
     DQ,
     DK,
@@ -619,6 +620,7 @@ def backward_kernel_one_col_block_sparse(
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
+    RETURN_SEL_GRADS: tl.constexpr,
     OFF_SEL_KV_BLOCKS: tl.constexpr
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
@@ -638,9 +640,6 @@ def backward_kernel_one_col_block_sparse(
 
     # initialize pointers to value-like data
 
-    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
-    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
-
     q_ptrs = (
         Q +
         offs_g[None, :, None] * stride_qh +
@@ -794,9 +793,9 @@ def backward_kernel_one_col_block_sparse(
         block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
         qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM
 
-        qk += tl.where(block_masks[:, None, None], 0, float("-inf"))
+        masked_qk = qk + tl.where(block_masks[:, None, None], 0, float("-inf"))
 
-        p = tl.exp(qk * softmax_scale - lse_i[:, :, None])
+        p = tl.exp(masked_qk * softmax_scale - lse_i[:, :, None])
 
         # take care of block dv
 
@@ -823,6 +822,26 @@ def backward_kernel_one_col_block_sparse(
 
         ds = (p * (dp - Di[:, :, None]) * softmax_scale)
 
+        # maybe return gradients for better differentiable topk
+
+        if RETURN_SEL_GRADS:
+
+            kv_block_grads_ptrs = (
+                kv_block_grads +
+                offs_m * stride_kvbl_m
+            )
+
+            sel_grads = ds * qk
+            sel_grads = tl.where(block_masks[:, None, None], sel_grads, 0.)
+            sel_grads = sel_grads.reshape(BLOCK, QUERY_HEAD_GROUPS * BLOCK)
+            sel_grads = tl.sum(sel_grads, 1)
+
+            tl.atomic_add(
+                kv_block_grads_ptrs + OFF_SEL_KV_BLOCKS,
+                sel_grads,
+                sem = 'relaxed'
+            )
+
         # block dk
 
         block_dk = ds[:, :, :, None] * q[:, :, None, :].to(ds.dtype)
@@ -1145,6 +1164,7 @@ def backward_kernel(
     V,
     kv_block_indices,
     kv_block_mask,
+    kv_block_grads,
     DO,
     DQ,
     DK,
@@ -1192,19 +1212,16 @@ def backward_kernel(
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
+    RETURN_SEL_GRADS: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // kv_heads
     off_h = off_hb % kv_heads
     off_qh = off_h * QUERY_HEAD_GROUPS
 
-    if INCLUDE_BLOCK_CAUSAL:
-        IS_CAUSAL = tl.program_id(0) == 0
-        OFF_SEL_KV_BLOCKS = tl.program_id(0) - 1
-    else:
-        IS_CAUSAL = False
-        OFF_SEL_KV_BLOCKS = tl.program_id(0)
+    OFF_SEL_KV_BLOCKS = tl.program_id(0) - int(INCLUDE_BLOCK_CAUSAL)
+    IS_CAUSAL = INCLUDE_BLOCK_CAUSAL and tl.program_id(0) == 0
 
     # offset pointers for batch/head
 
@@ -1220,6 +1237,7 @@ def backward_kernel(
 
     kv_block_indices += off_b * stride_kvbl_b + off_h * stride_kvbl_h
     kv_block_mask += off_b * stride_kvbl_b + off_h * stride_kvbl_h
+    kv_block_grads += off_b * stride_kvbl_b + off_h * stride_kvbl_h
 
     # pointer to row-wise quantities in value-like data
 
@@ -1283,6 +1301,7 @@ def backward_kernel(
         V,
         kv_block_indices,
         kv_block_mask,
+        kv_block_grads,
         DO,
         DQ,
         DK,
@@ -1312,6 +1331,7 @@ def backward_kernel(
         BLOCK = BLOCK,
         QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
         QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
+        RETURN_SEL_GRADS = RETURN_SEL_GRADS,
         OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS
     )
 
@@ -1320,11 +1340,13 @@ def native_sparse_attn_backward(
     q, k, v,
     kv_block_indices,
     kv_block_mask,
+    kv_block_grads,
     o,
     lse,
     dq, dk, dv,
     block_size = 128,
-    include_block_causal = True
+    include_block_causal = True,
+    return_sel_grads = False
 ):
     device = do.device
 
@@ -1387,6 +1409,7 @@ def native_sparse_attn_backward(
         v,
         kv_block_indices,
         kv_block_mask,
+        kv_block_grads,
         do,
         dq,
         dk,
@@ -1436,6 +1459,7 @@ def native_sparse_attn_backward(
         EVEN_M = divisible_by(seqlen_q, block_size),
         EVEN_N = divisible_by(seqlen_k, block_size),
         EVEN_HEADDIM = BLOCK_HEADDIM == dim,
+        RETURN_SEL_GRADS = return_sel_grads,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
@@ -1458,6 +1482,7 @@ def forward(
         block_size,
         selected_block_indices,
         fmask,
+        sel_scale,
         include_block_causal
     ):
         dtype = fq.dtype
@@ -1478,10 +1503,16 @@ def forward(
 
         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
 
+        return_sel_grads = exists(sel_scale)
+
+        if return_sel_grads:
+            assert (sel_scale == 1.).all(), 'for now, must be straight through as multiplier of 1.'
+
         ctx._saved_variables = (
             block_size,
             head_groups,
-            include_block_causal
+            return_sel_grads,
+            include_block_causal,
         )
 
         return out.type(dtype), lse
@@ -1495,6 +1526,7 @@ def backward(self, ctx, do, _):
         (
             block_size,
             head_groups,
+            return_sel_grads,
             include_block_causal
         ) = ctx._saved_variables
 
@@ -1503,15 +1535,23 @@ def backward(self, ctx, do, _):
         dk = torch.zeros(k.shape, dtype = torch.float32, device = device)
         dv = torch.zeros(v.shape, dtype = torch.float32, device = device)
 
+        sel_grads = torch.zeros_like(sel_block_indices).float()
+
         native_sparse_attn_backward(
             do, q, k, v,
-            sel_block_indices, mask,
+            sel_block_indices, mask, sel_grads,
             out, lse, dq, dk, dv,
             block_size = block_size,
-            include_block_causal = include_block_causal
+            include_block_causal = include_block_causal,
+            return_sel_grads = return_sel_grads
         )
 
-        return dq, dk, dv, None, None, None, None
+        ret_sel_grads = None
+
+        if return_sel_grads:
+            ret_sel_grads = sel_grads
+
+        return dq, dk, dv, None, None, None, ret_sel_grads, None
 
 _native_sparse_attend = NSA.apply
 
@@ -1531,6 +1571,7 @@ def native_sparse_attend(
     block_size: int,
     selected_block_indices: Int['b qh n sel'] | Int['b kh n sel'],
     fmask: Bool['b qh n sel'] | Bool['b kh n sel'],
+    sel_scale: Float['b kh n sel'] | Float['b qh n sel'] | None = None,
     include_block_causal = True,
     return_lse = False
 ):
@@ -1550,6 +1591,7 @@ def native_sparse_attend(
         block_size,
         selected_block_indices,
         fmask,
+        sel_scale,
         include_block_causal
     )
 
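For reference, a minimal usage sketch of the new sel_scale argument follows. It is not part of the commit: the import path, tensor sizes, and the CUDA requirement are assumptions; only the call signature, the all-ones constraint on sel_scale, and the gradient shape come from the diff above.

# usage sketch (assumed import path and shapes; Triton kernels need a CUDA device)

import torch
from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

batch, kv_heads, q_heads, seq_len, dim = 1, 2, 4, 512, 64
block_size, num_sel = 64, 2
num_blocks = seq_len // block_size

q = torch.randn(batch, q_heads, seq_len, dim, device = 'cuda', requires_grad = True)
k = torch.randn(batch, kv_heads, seq_len, dim, device = 'cuda', requires_grad = True)
v = torch.randn(batch, kv_heads, seq_len, dim, device = 'cuda', requires_grad = True)

# per (batch, kv head, query position): which kv blocks were selected, plus a validity mask

sel_indices = torch.randint(0, num_blocks, (batch, kv_heads, seq_len, num_sel), device = 'cuda')
sel_mask = torch.ones(batch, kv_heads, seq_len, num_sel, dtype = torch.bool, device = 'cuda')

# passing sel_scale switches on RETURN_SEL_GRADS in the backward kernel;
# per the assert in forward, it must currently be all ones (straight-through multiplier)

sel_scale = torch.ones(batch, kv_heads, seq_len, num_sel, device = 'cuda', requires_grad = True)

out = native_sparse_attend(q, k, v, block_size, sel_indices, sel_mask, sel_scale = sel_scale)

out.sum().backward()

# gradients w.r.t. each selected kv block, accumulated by the tl.atomic_add in the kernel
# (same shape as the selected block indices)

print(sel_scale.grad.shape)  # (batch, kv_heads, seq_len, num_sel)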