Skip to content

Commit 69da398

Browse files
committed
last commit for today for this project
1 parent e28f89b commit 69da398

File tree

1 file changed

+46
-17
lines changed

1 file changed

+46
-17
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,7 +1465,6 @@ def backward_kernel(
14651465
DV,
14661466
LSE,
14671467
D,
1468-
SLIDE_O,
14691468
SLIDE_DO,
14701469
SLIDE_LSE,
14711470
SLIDE_D,
@@ -1522,19 +1521,44 @@ def backward_kernel(
15221521
off_h = off_hb % kv_heads
15231522
off_qh = off_h * QUERY_HEAD_GROUPS
15241523

1525-
OFF_SEL_KV_BLOCKS = tl.program_id(0) - int(INCLUDE_BLOCK_CAUSAL)
1526-
IS_CAUSAL = INCLUDE_BLOCK_CAUSAL and tl.program_id(0) == 0
1524+
# determine whether block causal diagonal, sliding, or selected fine kv blocks
1525+
1526+
block_id = tl.program_id(0)
1527+
1528+
IS_CAUSAL = False
1529+
IS_SLIDING = False
1530+
1531+
do = DO
1532+
lse = LSE
1533+
delta = D
1534+
1535+
if INCLUDE_BLOCK_CAUSAL:
1536+
IS_CAUSAL = block_id == 0
1537+
block_id -= 1
1538+
1539+
if SLIDING:
1540+
IS_SLIDING = block_id == 0
1541+
block_id -= 1
1542+
1543+
if IS_SLIDING:
1544+
do = SLIDE_DO
1545+
lse = SLIDE_LSE
1546+
delta = SLIDE_D
1547+
1548+
OFF_SEL_KV_BLOCKS = block_id
15271549

15281550
# offset pointers for batch/head
15291551

15301552
Q += off_b * stride_qb + off_qh * stride_qh
15311553
K += off_b * stride_kb + off_h * stride_kh
15321554
V += off_b * stride_vb + off_h * stride_vh
1533-
DO += off_b * stride_dob + off_qh * stride_doh
1555+
15341556
DQ += off_b * stride_dqb + off_qh * stride_dqh
15351557
DK += off_b * stride_dkb + off_h * stride_dkh
15361558
DV += off_b * stride_dvb + off_h * stride_dvh
15371559

1560+
do += off_b * stride_dob + off_qh * stride_doh
1561+
15381562
# offset pointers for batch/head for selected kv block related
15391563

15401564
kv_block_indices += off_b * stride_kvbl_b + off_h * stride_kvbl_h
@@ -1543,32 +1567,32 @@ def backward_kernel(
15431567

15441568
# pointer to row-wise quantities in value-like data
15451569

1546-
D += (
1570+
delta += (
15471571
off_b * stride_D_b +
15481572
off_qh * seqlen_q_rounded
15491573
)
15501574

1551-
LSE += (
1575+
lse += (
15521576
off_b * stride_lse_b +
15531577
off_qh * seqlen_q_rounded
15541578
)
15551579

15561580
start_n = tl.program_id(2)
15571581

1558-
if IS_CAUSAL:
1582+
if IS_CAUSAL or IS_SLIDING:
15591583
backward_kernel_one_col_block_causal(
15601584
start_n,
15611585
Q,
15621586
K,
15631587
V,
15641588
kv_block_indices,
15651589
kv_block_mask,
1566-
DO,
1590+
do,
15671591
DQ,
15681592
DK,
15691593
DV,
1570-
LSE,
1571-
D,
1594+
lse,
1595+
delta,
15721596
softmax_scale,
15731597
stride_qm,
15741598
stride_kn,
@@ -1593,7 +1617,7 @@ def backward_kernel(
15931617
SEL_BLOCK = SEL_BLOCK,
15941618
QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
15951619
QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
1596-
SLIDING = SLIDING
1620+
SLIDING = IS_SLIDING
15971621
)
15981622
else:
15991623
backward_kernel_one_col_block_sparse(
@@ -1604,12 +1628,12 @@ def backward_kernel(
16041628
kv_block_indices,
16051629
kv_block_mask,
16061630
kv_block_grads,
1607-
DO,
1631+
do,
16081632
DQ,
16091633
DK,
16101634
DV,
1611-
LSE,
1612-
D,
1635+
lse,
1636+
delta,
16131637
softmax_scale,
16141638
stride_qm,
16151639
stride_kn,
@@ -1663,6 +1687,9 @@ def native_sparse_attn_backward(
16631687
if not is_contiguous(do):
16641688
do = do.contiguous()
16651689

1690+
if not is_contiguous(do_slide):
1691+
do_slide = do_slide.contiguous()
1692+
16661693
batch, q_heads, seqlen_q, dim = q.shape
16671694

16681695
_, kv_heads, seqlen_k, _ = k.shape
@@ -1740,7 +1767,7 @@ def native_sparse_attn_backward(
17401767
)
17411768

17421769
grid = lambda META: (
1743-
num_sel_fine_blocks + int(include_block_causal),
1770+
int(include_block_causal) + int(sliding) + num_sel_fine_blocks,
17441771
batch * kv_heads,
17451772
triton.cdiv(seqlen_k, META['BLOCK'])
17461773
)
@@ -1758,7 +1785,6 @@ def native_sparse_attn_backward(
17581785
dv,
17591786
lse,
17601787
delta,
1761-
slide_out,
17621788
do_slide,
17631789
slide_lse,
17641790
slide_delta,
@@ -1898,6 +1924,8 @@ def backward(self, ctx, do, do_slide, _, __):
18981924
) = ctx._saved_variables
18991925

19001926
do = do.half()
1927+
do_slide = do_slide.half()
1928+
19011929
dq = torch.zeros(q.shape, dtype = torch.float32, device = device)
19021930
dk = torch.zeros(k.shape, dtype = torch.float32, device = device)
19031931
dv = torch.zeros(v.shape, dtype = torch.float32, device = device)
@@ -1914,7 +1942,8 @@ def backward(self, ctx, do, do_slide, _, __):
19141942
block_size = block_size,
19151943
include_block_causal = include_block_causal,
19161944
return_sel_grads = return_sel_grads,
1917-
block_dk_dv_use_dot = block_dk_dv_use_dot
1945+
block_dk_dv_use_dot = block_dk_dv_use_dot,
1946+
sliding = return_sliding_window_out
19181947
)
19191948

19201949
ret_sel_grads = None

0 commit comments

Comments (0)