Skip to content

Commit 202b21e

Browse files
committed
save knocking out the forwards for tomorrow
1 parent 6606b67 commit 202b21e

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def _fwd_kernel(
7777
stride_ob,
7878
stride_oh,
7979
stride_om,
80+
stride_kvbl_b,
81+
stride_kvbl_h,
82+
stride_kvbl_m,
8083
nheads,
8184
seqlen_q,
8285
seqlen_k,
@@ -301,6 +304,9 @@ def flash_attn_forward(
301304
o.stride(0),
302305
o.stride(2),
303306
o.stride(1),
307+
kv_block_indices.stride(0),
308+
kv_block_indices.stride(2),
309+
kv_block_indices.stride(1),
304310
nheads,
305311
seqlen_q,
306312
seqlen_k,
@@ -929,6 +935,8 @@ def forward(
929935
fmask,
930936
num_grouped_queries
931937
):
938+
selected_block_indices, fmask = tuple(rearrange(t, 'b h i sel -> b i h sel') for t in (selected_block_indices, fmask))
939+
932940
fq, fk, fv = tuple(rearrange(t, 'b h n d -> b n h d') for t in (fq, fk, fv))
933941

934942
dtype = fq.dtype

0 commit comments

Comments
 (0)