Skip to content

Commit 6e52c4b

Browse files
committed
fixed, resuming with backwards gqa..
1 parent 9d03015 commit 6e52c4b

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def forward_kernel(
144144

145145
lse_ptrs = (
146146
Lse +
147-
(off_hb + offs_lse_qh[:, None]) * seqlen_q_rounded +
147+
offs_qh[:, None] * seqlen_q_rounded +
148148
offs_m[None, :]
149149
)
150150

test_triton_nsa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def regular_attend(
9393

9494
fine_block_size = 16
9595

96-
q = torch.randn(1, 2, 512, 64).cuda()
96+
q = torch.randn(1, 4, 512, 64).cuda()
9797
k = torch.randn(1, 2, 512, 64).cuda()
9898
v = torch.randn(1, 2, 512, 64).cuda()
9999

0 commit comments

Comments
 (0)