
Commit af261a6

dk down, dq to go
1 parent 9007ac0 commit af261a6

2 files changed: +24 -1 lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 22 additions & 0 deletions
@@ -773,11 +773,33 @@ def backward_kernel_one_col_block(
 
         p = tl.exp(qk * softmax_scale - lse_i[:, None])
 
+        # take care of block dv
+
         block_dv = p.to(do.dtype)[:, :, None] * do[:, None, :]
         block_dv = tl.where(block_masks[:, None, None], block_dv, 0.)
 
         tl.atomic_add(block_dv_ptrs, block_dv, sem = 'relaxed')
 
+        # get dp
+
+        do_expanded = tl.expand_dims(do, 1)
+        do_expanded = tl.broadcast_to(do_expanded, (BLOCK, 16, BLOCK_HEADDIM))
+        block_v = tl.permute(block_v, (0, 2, 1))
+
+        dp = tl.dot(do_expanded, block_v)
+        dp = tl.sum(dp, 1) / 16.
+
+        # ds
+
+        ds = (p * (dp - Di[:, None]) * softmax_scale)
+        ds = ds.to(q.dtype)
+
+        # block dk
+
+        block_dk = ds[:, :, None] * q[:, None, :]
+
+        tl.atomic_add(block_dk_ptrs, block_dk, sem = 'relaxed')
+
         # # increment pointers
         # dq_ptrs += BLOCK * stride_dqm
         # q_ptrs += BLOCK * stride_qm
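For orientation, the quantities these added lines accumulate follow the standard attention backward: dV = Pᵀ·dO, dP = dO·Vᵀ, dS = P ∘ (dP − D)·scale, dK = dSᵀ·Q, with dQ = dS·K still missing per the commit title. The expand_dims / broadcast_to / sum-divide-by-16 dance around tl.dot appears to be there because tl.dot needs every matmul dimension to be at least 16, so the per-row dot product between dO and the selected value block is computed as a 16-way replicated matmul and then averaged back down. Below is a minimal dense-attention sketch of the same math in plain PyTorch; the helper name, shapes, and dense layout are illustrative assumptions, not the repo's API, and the real kernel applies these formulas per selected key block and folds them in with tl.atomic_add.

import torch

# hypothetical dense reference; q, do: (n, d), k, v: (m, d)
def attention_backward_reference(q, k, v, do, softmax_scale):
    s = (q @ k.transpose(-1, -2)) * softmax_scale
    p = s.softmax(dim = -1)                        # what the kernel rebuilds as exp(qk * scale - lse)
    o = p @ v

    dv = p.transpose(-1, -2) @ do                  # block_dv: P^T @ dO
    dp = do @ v.transpose(-1, -2)                  # dp: dO @ V^T (the broadcast + tl.dot trick)
    Di = (do * o).sum(dim = -1, keepdim = True)    # the Di rowsum term
    ds = p * (dp - Di) * softmax_scale             # ds
    dk = ds.transpose(-1, -2) @ q                  # block_dk: dS^T @ Q
    dq = ds @ k                                    # dq, still "to go" per the commit message
    return dq, dk, dv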

test_triton_nsa.py

Lines changed: 2 additions & 1 deletion
@@ -105,5 +105,6 @@ def regular_attend(
 assert torch.allclose(out, nsa_out, atol = 1e-2)
 
 assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
-assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
+print((nk.grad - rk.grad).abs().amax())
 assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
+assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
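The test change moves the nq.grad assert behind the nk.grad one, presumably because dq is not wired up yet, and adds a print of the worst-case dk deviation for debugging. As a self-contained illustration of the same check pattern, here is a hedged sketch that validates the manual formulas above against torch.autograd on dense attention; all names, shapes, and tolerances are illustrative, not the test's actual setup.

import torch

# illustrative shapes only
n, m, d, scale = 32, 64, 16, 16 ** -0.5

q, k, v = (torch.randn(s, d, requires_grad = True) for s in (n, m, m))
do = torch.randn(n, d)

# autograd reference
o = ((q @ k.t()) * scale).softmax(dim = -1) @ v
o.backward(do)

# manual formulas, matching what the kernel accumulates
with torch.no_grad():
    p = ((q @ k.t()) * scale).softmax(dim = -1)
    dp = do @ v.t()
    Di = (do * (p @ v)).sum(dim = -1, keepdim = True)
    ds = p * (dp - Di) * scale
    dk = ds.t() @ q
    dv = p.t() @ do

print((dk - k.grad).abs().amax())   # mirrors the debug print added to the test
assert torch.allclose(dv, v.grad, atol = 1e-4)
assert torch.allclose(dk, k.grad, atol = 1e-4)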
