@@ -742,7 +742,9 @@ def backward_kernel_one_col_block_sparse(
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
     RETURN_SEL_GRADS: tl.constexpr,
-    OFF_SEL_KV_BLOCKS: tl.constexpr
+    OFF_SEL_KV_BLOCKS: tl.constexpr,
+    BLOCK_DV_USE_DOT: tl.constexpr,
+    BLOCK_DK_USE_DOT: tl.constexpr,
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)

@@ -918,23 +920,33 @@ def backward_kernel_one_col_block_sparse(

     p = tl.exp(masked_qk * softmax_scale - lse_i[:, :, None])

+    # prepare do
+
+    do_expanded = tl.expand_dims(do, 2)
+    do_expanded = tl.broadcast_to(do_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+    do_expanded = do_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)
+
     # take care of block dv

-    block_dv = p.to(do.dtype)[:, :, :, None] * do[:, :, None, :]
+    if BLOCK_DV_USE_DOT:
+        p_expanded = p.to(do.dtype)
+        p_expanded = tl.expand_dims(p_expanded, 2)
+        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
+        p_expanded = p_expanded.reshape(BLOCK, QUERY_HEAD_GROUPS * QUERY_EXPAND_DIM, BLOCK)
+        p_expanded = tl.permute(p_expanded, (0, 2, 1))

-    block_dv = tl.sum(block_dv, 1)
+        block_dv = tl.dot(p_expanded, do_expanded) / QUERY_EXPAND_DIM
+    else:
+        block_dv = p.to(do.dtype)[:, :, :, None] * do[:, :, None, :]
+        block_dv = tl.sum(block_dv, 1)

     tl.atomic_add(block_dv_ptrs, block_dv, mask = block_masks[:, None, None], sem = 'relaxed')

     # get dp

-    do = tl.expand_dims(do, 2)
-    do = tl.broadcast_to(do, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
-    do = do.reshape(BLOCK, 16, BLOCK_HEADDIM)
-
     block_v = tl.permute(block_v, (0, 2, 1))

-    dp = tl.dot(do, block_v)
+    dp = tl.dot(do_expanded, block_v)

     dp = dp.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
     dp = tl.sum(dp, 2) / QUERY_EXPAND_DIM
@@ -963,10 +975,23 @@ def backward_kernel_one_col_block_sparse(
         sem = 'relaxed'
     )

+    # ds
+
+    ds_expanded = tl.expand_dims(ds, 2)
+    ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
+    ds_expanded = ds_expanded.reshape(BLOCK, 16, BLOCK)
+
     # block dk

-    block_dk = ds[:, :, :, None] * q[:, :, None, :].to(ds.dtype)
-    block_dk = tl.sum(block_dk, 1)
+    if BLOCK_DK_USE_DOT:
+        q_expanded = tl.expand_dims(q, 2)
+        q_expanded = tl.broadcast_to(q_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+        q_expanded = q_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)
+
+        block_dk = tl.dot(tl.permute(ds_expanded, (0, 2, 1)), q_expanded.to(ds.dtype)) / QUERY_EXPAND_DIM
+    else:
+        block_dk = ds[:, :, :, None] * q[:, :, None, :].to(ds.dtype)
+        block_dk = tl.sum(block_dk, 1)

     tl.atomic_add(
         block_dk_ptrs,
@@ -977,11 +1002,7 @@ def backward_kernel_one_col_block_sparse(

     # block dq

-    ds_expanded = tl.expand_dims(ds, 2)
-    ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
-    ds_expanded = ds_expanded.reshape(BLOCK, 16, BLOCK)
     ds_expanded = ds_expanded.to(block_k.dtype)
-
     block_dq = tl.dot(ds_expanded, block_k)

     block_dq = block_dq.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
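Not part of the commit, but worth spelling out: the `tl.dot` branches above compute exactly the reduction performed by the broadcast-multiply-and-`tl.sum` lines they replace. `p` and `do` are duplicated `QUERY_EXPAND_DIM` times so the contraction dimension reaches 16 (the minimum size `tl.dot` accepts), the batched matmul then contracts over all `QUERY_HEAD_GROUPS * QUERY_EXPAND_DIM` copies, and the division by `QUERY_EXPAND_DIM` cancels the duplication. A quick PyTorch sanity check of that identity, with made-up shapes (`G`, `E`, `D` standing in for `QUERY_HEAD_GROUPS`, `QUERY_EXPAND_DIM`, `BLOCK_HEADDIM`):

```python
import torch

BLOCK, G, D = 64, 4, 64
E = 16 // G                                   # QUERY_EXPAND_DIM

p  = torch.randn(BLOCK, G, BLOCK)             # attention probabilities per head group
do = torch.randn(BLOCK, G, D)                 # incoming output gradients

# reference path: broadcast multiply, then sum over the head-group axis
dv_ref = (p[:, :, :, None] * do[:, :, None, :]).sum(dim = 1)                       # (BLOCK, BLOCK, D)

# dot path: duplicate along E so the inner dim is 16, contract with a batched
# matmul, then divide out the duplication
do_exp = do[:, :, None, :].expand(BLOCK, G, E, D).reshape(BLOCK, G * E, D)
p_exp  = p[:, :, None, :].expand(BLOCK, G, E, BLOCK).reshape(BLOCK, G * E, BLOCK).transpose(1, 2)

dv_dot = (p_exp @ do_exp) / E                                                       # (BLOCK, BLOCK, D)

assert torch.allclose(dv_ref, dv_dot, atol = 1e-5)
```

The same identity covers `block_dk`, with `ds` and `q` taking the roles of `p` and `do`.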
@@ -1348,6 +1369,8 @@ def backward_kernel(
     RETURN_SEL_GRADS: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr,
     SLIDING: tl.constexpr,
+    BLOCK_DV_USE_DOT: tl.constexpr,
+    BLOCK_DK_USE_DOT: tl.constexpr,
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // kv_heads
@@ -1467,7 +1490,9 @@ def backward_kernel(
             QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
             QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
             RETURN_SEL_GRADS = RETURN_SEL_GRADS,
-            OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS
+            OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS,
+            BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT,
+            BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT,
         )

 def native_sparse_attn_backward(
@@ -1482,7 +1507,8 @@ def native_sparse_attn_backward(
     block_size = 128,
     include_block_causal = True,
     return_sel_grads = False,
-    sliding = False
+    sliding = False,
+    block_dk_dv_use_dot = None
 ):
     device = do.device

@@ -1596,7 +1622,9 @@ def native_sparse_attn_backward(
         EVEN_HEADDIM = BLOCK_HEADDIM == dim,
         RETURN_SEL_GRADS = return_sel_grads,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
-        SLIDING = sliding
+        SLIDING = sliding,
+        BLOCK_DV_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1),
+        BLOCK_DK_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1)
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
         # num_stages=1,
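A note on the new flags (not part of the diff): when the caller leaves `block_dk_dv_use_dot` as `None`, both `BLOCK_DV_USE_DOT` and `BLOCK_DK_USE_DOT` fall back to `head_groups > 1`, i.e. the `tl.dot` path is only enabled by default when several query heads share a kv head. The `default` helper is not shown in this hunk; it is presumably the usual two-liner defined elsewhere in the file, along these lines:

```python
def exists(v):
    return v is not None

def default(v, d):
    # return v if it was provided, otherwise the fallback d
    return v if exists(v) else d
```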
@@ -1619,7 +1647,8 @@ def forward(
         selected_block_indices,
         fmask,
         sel_scale,
-        include_block_causal
+        include_block_causal,
+        block_dk_dv_use_dot
     ):
         dtype = fq.dtype

@@ -1649,6 +1678,7 @@ def forward(
             head_groups,
             return_sel_grads,
             include_block_causal,
+            block_dk_dv_use_dot
         )

         return out.type(dtype), lse
@@ -1663,7 +1693,8 @@ def backward(self, ctx, do, _):
             block_size,
             head_groups,
             return_sel_grads,
-            include_block_causal
+            include_block_causal,
+            block_dk_dv_use_dot
         ) = ctx._saved_variables

         do = do.half()
@@ -1679,15 +1710,16 @@ def backward(self, ctx, do, _):
             out, lse, dq, dk, dv,
             block_size = block_size,
             include_block_causal = include_block_causal,
-            return_sel_grads = return_sel_grads
+            return_sel_grads = return_sel_grads,
+            block_dk_dv_use_dot = block_dk_dv_use_dot
         )

         ret_sel_grads = None

         if return_sel_grads:
             ret_sel_grads = sel_grads

-        return dq, dk, dv, None, None, None, ret_sel_grads, None
+        return dq, dk, dv, None, None, None, ret_sel_grads, None, None

 _native_sparse_attend = NSA.apply

@@ -1709,7 +1741,8 @@ def native_sparse_attend(
     fmask: Bool['b qh n sel'] | Bool['b kh n sel'],
     sel_scale: Float['b kh n sel'] | Float['b qh n sel'] | None = None,
     include_block_causal = True,
-    return_lse = False
+    return_lse = False,
+    block_dk_dv_use_dot = False
 ):
     seq_len = fq.shape[-2]
     q_heads, kv_heads, sel_heads = fq.shape[1], fk.shape[1], selected_block_indices.shape[1]
@@ -1730,7 +1763,8 @@ def native_sparse_attend(
         selected_block_indices,
         fmask,
         sel_scale,
-        include_block_causal
+        include_block_causal,
+        block_dk_dv_use_dot
     )

     if not return_lse:
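For completeness, a minimal usage sketch of the new flag. The leading positional parameters (queries, keys, values, block size, selected block indices), the module path, and all shapes below are assumptions taken from context rather than shown in this diff; the one thing the commit adds is the `block_dk_dv_use_dot` keyword, which forces the `tl.dot` dk/dv accumulation in the backward kernel.

```python
import torch
from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

# hypothetical GQA setup: 8 query heads over 2 kv heads, 2 selected blocks per query
b, qh, kh, n, d = 1, 8, 2, 512, 64
block_size, num_sel = 64, 2

fq = torch.randn(b, qh, n, d, device = 'cuda', dtype = torch.float16, requires_grad = True)
fk = torch.randn(b, kh, n, d, device = 'cuda', dtype = torch.float16, requires_grad = True)
fv = torch.randn(b, kh, n, d, device = 'cuda', dtype = torch.float16, requires_grad = True)

# every query selects block 0, with all selections active
selected_block_indices = torch.zeros(b, kh, n, num_sel, device = 'cuda', dtype = torch.long)
fmask = torch.ones(b, kh, n, num_sel, device = 'cuda', dtype = torch.bool)

out = native_sparse_attend(
    fq, fk, fv,
    block_size,
    selected_block_indices,
    fmask,
    block_dk_dv_use_dot = True   # take the tl.dot path for block dk / dv
)

out.sum().backward()
```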