
Commit 9007ac0

dv down, dq/dk to go
1 parent ee3745c commit 9007ac0

2 files changed: 30 additions & 11 deletions

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 29 additions & 10 deletions
@@ -473,18 +473,18 @@ def backward_store_dk_dv(
     # if we just call tl.store(dv_ptrs), there's a race condition
     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
-            tl.store(dv_ptrs, dv)
-            tl.store(dk_ptrs, dk)
+            tl.atomic_add(dv_ptrs, dv, sem = 'relaxed')
+            tl.atomic_add(dk_ptrs, dk, sem = 'relaxed')
         else:
-            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
-            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
+            tl.atomic_add(dv_ptrs, dv, mask=offs_d[None, :] < headdim, sem = 'relaxed')
+            tl.atomic_add(dk_ptrs, dk, mask=offs_d[None, :] < headdim, sem = 'relaxed')
     else:
         if EVEN_HEADDIM:
-            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
-            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
+            tl.atomic_add(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k, sem = 'relaxed')
+            tl.atomic_add(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k, sem = 'relaxed')
         else:
-            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
-            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
+            tl.atomic_add(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), sem = 'relaxed')
+            tl.atomic_add(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), sem = 'relaxed')


 @triton.jit
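
Note on the hunk above: the existing comment already flags that a plain tl.store on dv_ptrs races, presumably because more than one program can target the same dk/dv rows, so the last store would silently overwrite earlier partial gradients. Switching to tl.atomic_add turns each write into an accumulation, and the relaxed ordering (sem = 'relaxed') should be sufficient here since nothing reads these buffers until the kernel has finished. A toy PyTorch sketch of the lost-update problem the atomics avoid (names and shapes are illustrative, not from the kernel):

# toy_lost_update.py -- hypothetical illustration, not the Triton kernel
import torch

# two "programs" each hold a partial dv for the same key/value rows
partial_dvs = [torch.ones(4, 8), 2 * torch.ones(4, 8)]

dv_store = torch.zeros(4, 8)
for partial in partial_dvs:
    dv_store = partial.clone()      # like tl.store: last writer wins, the sum is lost

dv_atomic = torch.zeros(4, 8)
for partial in partial_dvs:
    dv_atomic += partial            # like tl.atomic_add: contributions accumulate

assert dv_store[0, 0].item() == 2.0 and dv_atomic[0, 0].item() == 3.0
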
@@ -751,6 +751,14 @@ def backward_kernel_one_col_block(
         V + blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :]
     )

+    block_dv_ptrs = (
+        DV + blocks_offs_n[:, :, None] * stride_dvn + offs_d[None, None, :]
+    )
+
+    block_dk_ptrs = (
+        DK + blocks_offs_n[:, :, None] * stride_dkn + offs_d[None, None, :]
+    )
+
     block_k = tl.load(block_k_ptrs)
     block_v = tl.load(block_v_ptrs)

@@ -765,6 +773,11 @@ def backward_kernel_one_col_block(

     p = tl.exp(qk * softmax_scale - lse_i[:, None])

+    block_dv = p.to(do.dtype)[:, :, None] * do[:, None, :]
+    block_dv = tl.where(block_masks[:, None, None], block_dv, 0.)
+
+    tl.atomic_add(block_dv_ptrs, block_dv, sem = 'relaxed')
+
     # # increment pointers
     # dq_ptrs += BLOCK * stride_dqm
     # q_ptrs += BLOCK * stride_qm

@@ -965,6 +978,8 @@ def flash_attn_backward(
     softmax_scale = dim ** -0.5

     dq_accum = torch.empty_like(q, dtype = torch.float32)
+    dk_accum = torch.empty_like(k, dtype = torch.float32)
+    dv_accum = torch.empty_like(v, dtype = torch.float32)

     # delta = torch.zeros_like(lse)

@@ -995,6 +1010,7 @@ def flash_attn_backward(
         triton.cdiv(seqlen_k, META["BLOCK"]) if META["SEQUENCE_PARALLEL"] else 1,
         batch * nheads,
     )
+
     backward_kernel[grid](
         q,
         k,

@@ -1003,8 +1019,8 @@ def flash_attn_backward(
         kv_block_mask,
         do,
         dq_accum,
-        dk,
-        dv,
+        dk_accum,
+        dv_accum,
         lse,
         delta,
         softmax_scale,

@@ -1052,7 +1068,10 @@ def flash_attn_backward(
         # num_warps=num_warps,
         # num_stages=1,
     )
+
     dq.copy_(dq_accum)
+    dk.copy_(dk_accum)
+    dv.copy_(dv_accum)

     return delta
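
Taken together with the test change below, the dv path of the backward pass is now wired end to end: DV/DK pointer grids are set up per column block, the per-block dv contribution is masked and atomically accumulated into a float32 buffer (dv_accum), and the result is copied back into dv after the kernel launch; dk_accum is allocated and passed through, but per the commit title the dq/dk contributions are still to come. Ignoring the sparse block selection (block_masks), the accumulated quantity is the usual attention backward term dV = P^T @ dO summed over query blocks. A rough float32 reference in plain PyTorch, with hypothetical sizes, not the kernel itself:

# reference_dv.py -- rough sketch of the accumulate-then-copy pattern
import torch

q_len, kv_len, dim, block = 128, 128, 64, 32           # hypothetical sizes

p  = torch.rand(q_len, kv_len)                          # attention probabilities
do = torch.randn(q_len, dim)                            # gradient w.r.t. the output

dv_accum = torch.zeros(kv_len, dim, dtype = torch.float32)

for start in range(0, q_len, block):                    # one iteration ~ one query-block program
    p_blk  = p[start:start + block]
    do_blk = do[start:start + block]
    # equivalent of block_dv = p[:, :, None] * do[:, None, :] followed by tl.atomic_add
    dv_accum += torch.einsum('qk,qd->kd', p_blk, do_blk)

dv = torch.empty(kv_len, dim, dtype = torch.float16)
dv.copy_(dv_accum)                                       # mirrors dv.copy_(dv_accum) above

# sanity check against the unblocked computation
assert torch.allclose(dv_accum, p.t() @ do, atol = 1e-4)
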
test_triton_nsa.py

Lines changed: 1 addition & 1 deletion
@@ -104,6 +104,6 @@ def regular_attend(

 assert torch.allclose(out, nsa_out, atol = 1e-2)

+assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
 assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
 assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
-assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
