@@ -1465,7 +1465,6 @@ def backward_kernel(
14651465 DV ,
14661466 LSE ,
14671467 D ,
1468- SLIDE_O ,
14691468 SLIDE_DO ,
14701469 SLIDE_LSE ,
14711470 SLIDE_D ,
@@ -1522,19 +1521,44 @@ def backward_kernel(
15221521 off_h = off_hb % kv_heads
15231522 off_qh = off_h * QUERY_HEAD_GROUPS
15241523
1525- OFF_SEL_KV_BLOCKS = tl .program_id (0 ) - int (INCLUDE_BLOCK_CAUSAL )
1526- IS_CAUSAL = INCLUDE_BLOCK_CAUSAL and tl .program_id (0 ) == 0
1524+ # determine from the axis-0 program id whether this program handles the block causal diagonal, the sliding window, or one of the selected fine kv blocks
1525+
1526+ block_id = tl .program_id (0 )
1527+
1528+ IS_CAUSAL = False
1529+ IS_SLIDING = False
1530+
1531+ do = DO
1532+ lse = LSE
1533+ delta = D
1534+
1535+ if INCLUDE_BLOCK_CAUSAL :
1536+ IS_CAUSAL = block_id == 0
1537+ block_id -= 1
1538+
1539+ if SLIDING :
1540+ IS_SLIDING = block_id == 0
1541+ block_id -= 1
1542+
1543+ if IS_SLIDING :
1544+ do = SLIDE_DO
1545+ lse = SLIDE_LSE
1546+ delta = SLIDE_D
1547+
1548+ OFF_SEL_KV_BLOCKS = block_id
15271549
15281550 # offset pointers for batch/head
15291551
15301552 Q += off_b * stride_qb + off_qh * stride_qh
15311553 K += off_b * stride_kb + off_h * stride_kh
15321554 V += off_b * stride_vb + off_h * stride_vh
1533- DO += off_b * stride_dob + off_qh * stride_doh
1555+
15341556 DQ += off_b * stride_dqb + off_qh * stride_dqh
15351557 DK += off_b * stride_dkb + off_h * stride_dkh
15361558 DV += off_b * stride_dvb + off_h * stride_dvh
15371559
1560+ do += off_b * stride_dob + off_qh * stride_doh
1561+
15391563 # offset pointers for batch/head of the tensors related to the selected kv blocks
15391563
15401564 kv_block_indices += off_b * stride_kvbl_b + off_h * stride_kvbl_h
@@ -1543,32 +1567,32 @@ def backward_kernel(
15431567
15441568 # pointer to row-wise quantities in value-like data
15451569
1546- D += (
1570+ delta += (
15471571 off_b * stride_D_b +
15481572 off_qh * seqlen_q_rounded
15491573 )
15501574
1551- LSE += (
1575+ lse += (
15521576 off_b * stride_lse_b +
15531577 off_qh * seqlen_q_rounded
15541578 )
15551579
15561580 start_n = tl .program_id (2 )
15571581
1558- if IS_CAUSAL :
1582+ if IS_CAUSAL or IS_SLIDING :
15591583 backward_kernel_one_col_block_causal (
15601584 start_n ,
15611585 Q ,
15621586 K ,
15631587 V ,
15641588 kv_block_indices ,
15651589 kv_block_mask ,
1566- DO ,
1590+ do ,
15671591 DQ ,
15681592 DK ,
15691593 DV ,
1570- LSE ,
1571- D ,
1594+ lse ,
1595+ delta ,
15721596 softmax_scale ,
15731597 stride_qm ,
15741598 stride_kn ,
@@ -1593,7 +1617,7 @@ def backward_kernel(
15931617 SEL_BLOCK = SEL_BLOCK ,
15941618 QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
15951619 QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1596- SLIDING = SLIDING
1620+ SLIDING = IS_SLIDING
15971621 )
15981622 else :
15991623 backward_kernel_one_col_block_sparse (
@@ -1604,12 +1628,12 @@ def backward_kernel(
16041628 kv_block_indices ,
16051629 kv_block_mask ,
16061630 kv_block_grads ,
1607- DO ,
1631+ do ,
16081632 DQ ,
16091633 DK ,
16101634 DV ,
1611- LSE ,
1612- D ,
1635+ lse ,
1636+ delta ,
16131637 softmax_scale ,
16141638 stride_qm ,
16151639 stride_kn ,
@@ -1663,6 +1687,9 @@ def native_sparse_attn_backward(
16631687 if not is_contiguous (do ):
16641688 do = do .contiguous ()
16651689
1690+ if not is_contiguous (do_slide ):
1691+ do_slide = do_slide .contiguous ()
1692+
16661693 batch , q_heads , seqlen_q , dim = q .shape
16671694
16681695 _ , kv_heads , seqlen_k , _ = k .shape
@@ -1740,7 +1767,7 @@ def native_sparse_attn_backward(
17401767 )
17411768
17421769 grid = lambda META : (
1743- num_sel_fine_blocks + int (include_block_causal ) ,
1770+ int ( include_block_causal ) + int (sliding ) + num_sel_fine_blocks ,
17441771 batch * kv_heads ,
17451772 triton .cdiv (seqlen_k , META ['BLOCK' ])
17461773 )
@@ -1758,7 +1785,6 @@ def native_sparse_attn_backward(
17581785 dv ,
17591786 lse ,
17601787 delta ,
1761- slide_out ,
17621788 do_slide ,
17631789 slide_lse ,
17641790 slide_delta ,
@@ -1898,6 +1924,8 @@ def backward(self, ctx, do, do_slide, _, __):
18981924 ) = ctx ._saved_variables
18991925
19001926 do = do .half ()
1927+ do_slide = do_slide .half ()
1928+
19011929 dq = torch .zeros (q .shape , dtype = torch .float32 , device = device )
19021930 dk = torch .zeros (k .shape , dtype = torch .float32 , device = device )
19031931 dv = torch .zeros (v .shape , dtype = torch .float32 , device = device )
@@ -1914,7 +1942,8 @@ def backward(self, ctx, do, do_slide, _, __):
19141942 block_size = block_size ,
19151943 include_block_causal = include_block_causal ,
19161944 return_sel_grads = return_sel_grads ,
1917- block_dk_dv_use_dot = block_dk_dv_use_dot
1945+ block_dk_dv_use_dot = block_dk_dv_use_dot ,
1946+ sliding = return_sliding_window_out
19181947 )
19191948
19201949 ret_sel_grads = None
0 commit comments