
Commit 338ffb3

let triton kernels see the block indices and mask
1 parent c491409 commit 338ffb3
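
For context (not part of the commit itself): the two tensors being threaded through to the Triton kernels are the indices of the selected fine KV blocks and a companion boolean mask marking which selected slots are valid. The diff only guarantees that the two tensors share a shape and that their last dimension is the number of selected blocks; the concrete layout in this sketch is an assumption for illustration.

import torch

# hypothetical shapes -- the commit only asserts kv_block_indices.shape == kv_block_mask.shape
# and reads the number of selected fine blocks off the last dimension
batch, seqlen, heads, num_sel, block_size = 1, 512, 8, 2, 64
num_kv_blocks = seqlen // block_size

# which fine KV blocks each (batch, position, head) attends to
kv_block_indices = torch.randint(0, num_kv_blocks, (batch, seqlen, heads, num_sel))

# which of those selected slots are actually valid
kv_block_mask = torch.ones(batch, seqlen, heads, num_sel, dtype = torch.bool)

assert kv_block_indices.shape == kv_block_mask.shape  # mirrors the new assert in the diff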

File tree

1 file changed: +44 −16 lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 44 additions & 16 deletions
@@ -59,6 +59,8 @@ def _fwd_kernel(
     Q,
     K,
     V,
+    KV_block_indices,
+    KV_block_mask,
     Out,
     M,
     Lse,
@@ -87,6 +89,7 @@ def _fwd_kernel(
     EVEN_N: tl.constexpr,
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
+    NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
@@ -243,15 +246,18 @@ def flash_attn_forward(
     q,
     k,
     v,
-    indices,
-    mask,
+    kv_block_indices,
+    kv_block_mask,
     block_size = 128
 ):
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]

     batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape

+    num_selected_fine_blocks = kv_block_indices.shape[-1]
+    assert kv_block_indices.shape == kv_block_mask.shape
+
     assert k.shape == (batch, seqlen_k, nheads, dim)
     assert v.shape == (batch, seqlen_k, nheads, dim)
     assert dim <= 128, "only support head dimensions up to 128"
@@ -277,6 +283,8 @@ def flash_attn_forward(
         q,
         k,
         v,
+        kv_block_indices,
+        kv_block_mask,
         o,
         m,
         lse,
@@ -302,6 +310,7 @@ def flash_attn_forward(
         seqlen_k // 32,
         BLOCK_HEADDIM,
         BLOCK = block_size,
+        NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
         num_warps = num_warps,
         num_stages = 1,
     )
@@ -398,6 +407,8 @@ def _bwd_kernel_one_col_block(
     Q,
     K,
     V,
+    kv_block_indices,
+    kv_block_mask,
     DO,
     DQ,
     DK,
@@ -421,6 +432,7 @@ def _bwd_kernel_one_col_block(
     EVEN_N: tl.constexpr,
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
+    NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
     begin_m = ((start_n * BLOCK) // BLOCK) * BLOCK
@@ -654,6 +666,8 @@ def _bwd_kernel(
     Q,
     K,
     V,
+    kv_block_indices,
+    kv_block_mask,
     DO,
     DQ,
     DK,
@@ -695,6 +709,7 @@ def _bwd_kernel(
     EVEN_N: tl.constexpr,
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
+    NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // nheads
@@ -718,6 +733,8 @@ def _bwd_kernel(
                 Q,
                 K,
                 V,
+                kv_block_indices,
+                kv_block_mask,
                 DO,
                 DQ,
                 DK,
@@ -735,12 +752,13 @@ def _bwd_kernel(
                 seqlen_q,
                 seqlen_k,
                 headdim,
-                ATOMIC_ADD=False,
-                BLOCK_HEADDIM=BLOCK_HEADDIM,
-                EVEN_M=EVEN_M,
-                EVEN_N=EVEN_N,
-                EVEN_HEADDIM=EVEN_HEADDIM,
-                BLOCK=BLOCK,
+                ATOMIC_ADD = False,
+                BLOCK_HEADDIM = BLOCK_HEADDIM,
+                EVEN_M = EVEN_M,
+                EVEN_N = EVEN_N,
+                EVEN_HEADDIM = EVEN_HEADDIM,
+                BLOCK = BLOCK,
+                NUM_SEL_KV_BLOCKS = NUM_SEL_KV_BLOCKS
             )
     else:
         start_n = tl.program_id(0)
@@ -749,6 +767,8 @@ def _bwd_kernel(
             Q,
             K,
             V,
+            kv_block_indices,
+            kv_block_mask,
             DO,
             DQ,
             DK,
@@ -766,19 +786,20 @@ def _bwd_kernel(
             seqlen_q,
             seqlen_k,
             headdim,
-            ATOMIC_ADD=True,
-            BLOCK_HEADDIM=BLOCK_HEADDIM,
-            EVEN_M=EVEN_M,
-            EVEN_N=EVEN_N,
-            EVEN_HEADDIM=EVEN_HEADDIM,
-            BLOCK=BLOCK,
+            ATOMIC_ADD = True,
+            BLOCK_HEADDIM = BLOCK_HEADDIM,
+            EVEN_M = EVEN_M,
+            EVEN_N = EVEN_N,
+            EVEN_HEADDIM = EVEN_HEADDIM,
+            BLOCK = BLOCK,
+            NUM_SEL_KV_BLOCKS = NUM_SEL_KV_BLOCKS
         )

 def flash_attn_backward(
     do,
     q, k, v,
-    indices,
-    mask,
+    kv_block_indices,
+    kv_block_mask,
     o,
     lse,
     dq, dk, dv,
@@ -790,6 +811,10 @@ def flash_attn_backward(

     batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
+
+    num_sel_fine_blocks = kv_block_indices.shape[-1]
+    assert kv_block_indices.shape == kv_block_mask.shape
+
     # assert d in {16, 32, 64, 128}
     assert dim <= 128
     seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
@@ -834,6 +859,8 @@ def flash_attn_backward(
         q,
         k,
         v,
+        kv_block_indices,
+        kv_block_mask,
         do,
         dq_accum,
         dk,
@@ -873,6 +900,7 @@ def flash_attn_backward(
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
         BLOCK_HEADDIM,
         BLOCK = block_size,
+        NUM_SEL_KV_BLOCKS = num_sel_fine_blocks,
         SEQUENCE_PARALLEL = False,
         EVEN_M = (seqlen_q % block_size) == 0,
         EVEN_N = (seqlen_k % block_size) == 0,
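
As a usage sketch (again, not part of the commit): calling the updated flash_attn_forward with the renamed arguments. The q/k/v shapes and the dim <= 128 limit follow the asserts visible in the diff; the module path, the exact layout of kv_block_indices / kv_block_mask, and the return value are assumptions, since the diff does not show them.

import torch
from native_sparse_attention_pytorch.triton_native_sparse_attention import flash_attn_forward

batch, seqlen, nheads, dim = 1, 512, 8, 64   # dim must be <= 128 per the assert in the diff
block_size = 128
num_sel = 2                                  # number of selected fine KV blocks (assumed)

q = torch.randn(batch, seqlen, nheads, dim, device = 'cuda', dtype = torch.float16)
k = torch.randn(batch, seqlen, nheads, dim, device = 'cuda', dtype = torch.float16)
v = torch.randn(batch, seqlen, nheads, dim, device = 'cuda', dtype = torch.float16)

# assumed layout: selected block indices and validity mask per (batch, position, head)
kv_block_indices = torch.randint(0, seqlen // block_size, (batch, seqlen, nheads, num_sel), device = 'cuda')
kv_block_mask = torch.ones_like(kv_block_indices, dtype = torch.bool)

# positional arguments follow the signature shown in the diff;
# the return value is not visible in this diff, so it is left unnamed here
outputs = flash_attn_forward(q, k, v, kv_block_indices, kv_block_mask, block_size = block_size)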
