Skip to content

Commit 32da919

Browse files
committed
Fix a bunch of striding issues. It all works now, except for allowing each query head to attend to different segments of kv; will finish that tomorrow.
1 parent 04603d9 commit 32da919

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 22 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,8 @@ def forward_kernel(
8282
stride_kvbl_b,
8383
stride_kvbl_h,
8484
stride_kvbl_m,
85-
nheads,
85+
stride_lse_b,
86+
kv_heads,
8687
seqlen_q,
8788
seqlen_k,
8889
seqlen_q_rounded,
@@ -100,9 +101,9 @@ def forward_kernel(
100101
):
101102
start_m = tl.program_id(0)
102103
off_hb = tl.program_id(1)
103-
off_b = off_hb // nheads
104104

105-
off_h = off_hb % nheads
105+
off_b = off_hb // kv_heads
106+
off_h = off_hb % kv_heads
106107

107108
offs_qh = off_h * QUERY_HEAD_GROUPS + tl.arange(0, QUERY_HEAD_GROUPS)
108109

@@ -142,6 +143,7 @@ def forward_kernel(
142143

143144
lse_ptrs = (
144145
Lse +
146+
off_b * stride_lse_b +
145147
offs_qh[:, None] * seqlen_q_rounded +
146148
offs_m[None, :]
147149
)
@@ -348,7 +350,7 @@ def forward_kernel(
348350
# write back lse
349351

350352
lse_i = lse_i.reshape([QUERY_HEAD_GROUPS, BLOCK])
351-
tl.store(lse_ptrs, lse_i)
353+
tl.store(lse_ptrs, lse_i, mask = offs_m[None, :] < seqlen_q)
352354

353355
# write to output
354356

@@ -429,7 +431,8 @@ def native_sparse_attn_forward(
429431
kv_block_indices.stride(0),
430432
kv_block_indices.stride(1),
431433
kv_block_indices.stride(2),
432-
nheads,
434+
lse.stride(0),
435+
kv_heads,
433436
seqlen_q,
434437
seqlen_k,
435438
seqlen_q_rounded,
@@ -591,7 +594,6 @@ def backward_kernel_one_col_block(
591594
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
592595
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
593596

594-
595597
q_ptrs = (
596598
Q +
597599
offs_g[:, None, None] * stride_qh +
@@ -956,6 +958,8 @@ def backward_kernel(
956958
stride_kvbl_b,
957959
stride_kvbl_h,
958960
stride_kvbl_m,
961+
stride_lse_b,
962+
stride_D_b,
959963
kv_heads,
960964
seqlen_q,
961965
seqlen_k,
@@ -993,8 +997,16 @@ def backward_kernel(
993997
kv_block_mask += off_b * stride_kvbl_b + off_h * stride_kvbl_h
994998

995999
# pointer to row-wise quantities in value-like data
996-
D += off_hb * QUERY_HEAD_GROUPS * seqlen_q_rounded
997-
LSE += off_hb * QUERY_HEAD_GROUPS * seqlen_q_rounded
1000+
1001+
D += (
1002+
off_b * stride_D_b +
1003+
off_h * QUERY_HEAD_GROUPS * seqlen_q_rounded
1004+
)
1005+
1006+
LSE += (
1007+
off_b * stride_lse_b +
1008+
off_h * QUERY_HEAD_GROUPS * seqlen_q_rounded
1009+
)
9981010

9991011
num_block_n = tl.cdiv(seqlen_k, BLOCK)
10001012
for start_n in range(0, num_block_n):
@@ -1137,6 +1149,8 @@ def native_sparse_attn_backward(
11371149
kv_block_indices.stride(0),
11381150
kv_block_indices.stride(1),
11391151
kv_block_indices.stride(2),
1152+
lse.stride(0),
1153+
delta.stride(0),
11401154
kv_heads,
11411155
seqlen_q,
11421156
seqlen_k,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.0.50"
3+
version = "0.0.51"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

train_triton_nsa.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
COMPRESS_BLOCK_SIZE = 16
4444

4545
FINE_BLOCK_SIZE = 16
46-
NUM_FINE_SELECTED = 0
46+
NUM_FINE_SELECTED = 1
4747

4848
INTERPOLATED_IMPORTANCE_SCORE = False
4949
USE_DIFF_TOPK = True

0 commit comments

Comments
 (0)