
Commit f6515e0

dq down
1 parent af261a6 commit f6515e0

3 files changed, +26 -51 lines changed


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 25 additions & 49 deletions
@@ -676,54 +676,21 @@ def backward_kernel_one_col_block(
         EVEN_M & EVEN_HEADDIM
     ):  # Otherewise there's a race condition when BIAS_TYPE='matrix'
         tl.debug_barrier()
-    if not ATOMIC_ADD:
-        if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
-            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-            dq += tl.dot(ds, k)
-            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
-        else:
-            if EVEN_HEADDIM:
-                dq = tl.load(
-                    dq_ptrs,
-                    mask=offs_m[:, None] < seqlen_q,
-                    other=0.0,
-                    eviction_policy="evict_last",
-                )
-                dq += tl.dot(ds, k)
-                tl.store(
-                    dq_ptrs,
-                    dq,
-                    mask=offs_m[:, None] < seqlen_q,
-                    eviction_policy="evict_last",
-                )
-            else:
-                dq = tl.load(
-                    dq_ptrs,
-                    mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
-                    other=0.0,
-                    eviction_policy="evict_last",
-                )
-                dq += tl.dot(ds, k)
-                tl.store(
-                    dq_ptrs,
-                    dq,
-                    mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
-                    eviction_policy="evict_last",
-                )
-    else:  # If we're parallelizing across the seqlen_k dimension
-        dq = tl.dot(ds, k)
-        if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
-            tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
+
+    dq = tl.dot(ds, k)
+
+    if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+        tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
+    else:
+        if EVEN_HEADDIM:
+            tl.atomic_add(dq_ptrs, dq, mask=offs_m[:, None] < seqlen_q, sem = 'relaxed')
         else:
-            if EVEN_HEADDIM:
-                tl.atomic_add(dq_ptrs, dq, mask=offs_m[:, None] < seqlen_q, sem = 'relaxed')
-            else:
-                tl.atomic_add(
-                    dq_ptrs,
-                    dq,
-                    mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
-                    sem = 'relaxed',
-                )
+            tl.atomic_add(
+                dq_ptrs,
+                dq,
+                mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                sem = 'relaxed',
+            )

     # handle kv block indices using atomic adds for starters, todo: swap dq and dk/dv loops at some point, semi big refactor

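The hunk above drops the non-atomic load/add/store path entirely: every column-block program now contributes its ds @ k term to dq through a relaxed atomic add. The following is a minimal PyTorch sketch of the accumulation this implies; shapes and block counts are assumed for illustration, and it is not the Triton kernel itself (which works on pointers and masks):

    import torch

    # assumed illustrative sizes, not taken from the kernel
    BLOCK_M, BLOCK_N, HEADDIM, NUM_KV_BLOCKS = 64, 64, 32, 4

    torch.manual_seed(0)
    ds_blocks = [torch.randn(BLOCK_M, BLOCK_N) for _ in range(NUM_KV_BLOCKS)]
    k_blocks = [torch.randn(BLOCK_N, HEADDIM) for _ in range(NUM_KV_BLOCKS)]

    # dq starts zeroed; each key block adds its ds @ k contribution,
    # standing in for tl.atomic_add(dq_ptrs, tl.dot(ds, k), sem = 'relaxed')
    dq = torch.zeros(BLOCK_M, HEADDIM)
    for ds, k in zip(ds_blocks, k_blocks):
        dq += ds @ k

    # addition commutes, so the order in which programs land their updates does not matter
    dq_reversed = torch.zeros(BLOCK_M, HEADDIM)
    for ds, k in reversed(list(zip(ds_blocks, k_blocks))):
        dq_reversed += ds @ k

    assert torch.allclose(dq, dq_reversed, atol = 1e-5)

With atomics, many programs can update the same rows of dq without the read-modify-write race the deleted eviction-policy load/store branch was guarding against, which is what lets the whole conditional block collapse to the few lines above.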
@@ -765,8 +732,8 @@ def backward_kernel_one_col_block(
         q_expanded = tl.expand_dims(q, 1)
         q_expanded = tl.broadcast_to(q_expanded, (BLOCK, 16, BLOCK_HEADDIM))

-        block_k = tl.permute(block_k, (0, 2, 1))
-        block_qk = tl.dot(q_expanded, block_k)
+        block_k_permuted = tl.permute(block_k, (0, 2, 1))
+        block_qk = tl.dot(q_expanded, block_k_permuted)

         qk = tl.sum(block_qk, 1) / 16.
         qk += tl.where(block_masks[:, None], 0, float("-inf"))
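The only change in this hunk is giving the permuted keys their own name, block_k_permuted, presumably so block_k survives in its original layout for the block-dq step added in the next hunk. For context, the surrounding lines use a broadcast-to-16 trick: tl.dot needs each operand dimension to be at least 16, so the single query row per block is replicated 16 times, multiplied against the gathered keys, and the 16 identical results are collapsed back with a sum divided by 16. A rough PyTorch sketch of that equivalence, with assumed shapes:

    import torch

    # assumed shapes: BLOCK queries, each gathering SELECTED keys of size HEADDIM
    BLOCK, HEADDIM, SELECTED = 64, 32, 16

    torch.manual_seed(0)
    q = torch.randn(BLOCK, HEADDIM)
    block_k = torch.randn(BLOCK, SELECTED, HEADDIM)

    # direct per-row scores: qk[i] = q[i] @ block_k[i].T
    qk_direct = torch.einsum('bd,bkd->bk', q, block_k)

    # kernel-style workaround: replicate each query 16 times so the batched
    # dot has a legal shape, then undo the duplication with sum(...) / 16
    q_expanded = q.unsqueeze(1).expand(BLOCK, 16, HEADDIM)
    block_k_permuted = block_k.permute(0, 2, 1)          # (BLOCK, HEADDIM, SELECTED)
    block_qk = q_expanded @ block_k_permuted             # (BLOCK, 16, SELECTED)
    qk_trick = block_qk.sum(dim = 1) / 16.

    assert torch.allclose(qk_direct, qk_trick, atol = 1e-4)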
@@ -800,6 +767,15 @@ def backward_kernel_one_col_block(

         tl.atomic_add(block_dk_ptrs, block_dk, sem = 'relaxed')

+        # block dq
+
+        ds_expanded = tl.expand_dims(ds, 1)
+        ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, 16, BLOCK))
+        block_dq = tl.dot(ds_expanded, block_k)
+        block_dq = tl.sum(block_dq, 1) / 16
+
+        tl.atomic_add(dq_ptrs, block_dq, sem = 'relaxed')
+
         # # increment pointers
         # dq_ptrs += BLOCK * stride_dqm
         # q_ptrs += BLOCK * stride_qm
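The newly added block-dq lines reuse the same broadcast trick for the gradient: ds is replicated along a size-16 middle dimension, dotted with the un-permuted block_k (which the previous hunk stopped overwriting), collapsed with a sum divided by 16, and atomically added into dq on top of the contribution accumulated earlier in the backward pass. A small PyTorch check of that identity, again with assumed shapes rather than the kernel's actual block parameters:

    import torch

    # assumed shapes: BLOCK queries, each with gradients over SELECTED gathered keys
    BLOCK, HEADDIM, SELECTED = 64, 32, 16

    torch.manual_seed(0)
    ds = torch.randn(BLOCK, SELECTED)
    block_k = torch.randn(BLOCK, SELECTED, HEADDIM)

    # direct per-row gradient: block_dq[i] = ds[i] @ block_k[i]
    block_dq_direct = torch.einsum('bk,bkd->bd', ds, block_k)

    # kernel-style version: broadcast ds to 16 identical rows, batched dot, collapse
    ds_expanded = ds.unsqueeze(1).expand(BLOCK, 16, SELECTED)
    block_dq = (ds_expanded @ block_k).sum(dim = 1) / 16   # (BLOCK, HEADDIM)

    assert torch.allclose(block_dq_direct, block_dq, atol = 1e-4)

    # in the kernel this result then goes through tl.atomic_add(dq_ptrs, block_dq, sem = 'relaxed')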

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.47"
+version = "0.0.48"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

test_triton_nsa.py

Lines changed: 0 additions & 1 deletion
@@ -105,6 +105,5 @@ def regular_attend(
 assert torch.allclose(out, nsa_out, atol = 1e-2)

 assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
-print((nk.grad - rk.grad).abs().amax())
 assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
 assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
