Commit 4d34b8a

last commit for today
1 parent 1a6019d commit 4d34b8a

File tree: 2 files changed, +19 additions, −15 deletions


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 18 additions & 14 deletions

@@ -677,20 +677,9 @@ def backward_kernel_one_col_block(
         ): # Otherewise there's a race condition when BIAS_TYPE='matrix'
             tl.debug_barrier()

-        dq = tl.dot(ds, k)
+        dq = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype = tl.float32)

-        if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
-            tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
-        else:
-            if EVEN_HEADDIM:
-                tl.atomic_add(dq_ptrs, dq, mask=offs_m[:, None] < seqlen_q, sem = 'relaxed')
-            else:
-                tl.atomic_add(
-                    dq_ptrs,
-                    dq,
-                    mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
-                    sem = 'relaxed',
-                )
+        dq += tl.dot(ds, k)

     # handle kv block indices using atomic adds for starters, todo: swap dq and dk/dv loops at some point, semi big refactor

@@ -774,7 +763,22 @@ def backward_kernel_one_col_block(
         block_dq = tl.dot(ds_expanded, block_k)
         block_dq = tl.sum(block_dq, 1) / 16

-        tl.atomic_add(dq_ptrs, block_dq, sem = 'relaxed')
+        dq += block_dq
+
+    # update dq
+
+    if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
+        tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
+    else:
+        if EVEN_HEADDIM:
+            tl.atomic_add(dq_ptrs, dq, mask=offs_m[:, None] < seqlen_q, sem = 'relaxed')
+        else:
+            tl.atomic_add(
+                dq_ptrs,
+                dq,
+                mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                sem = 'relaxed',
+            )

     # # increment pointers
     # dq_ptrs += BLOCK * stride_dqm
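Net effect of this change: instead of issuing a tl.atomic_add for every dq contribution inside the column-block loop, the backward kernel now accumulates dq in a local float32 tile (tl.zeros([BLOCK, BLOCK_HEADDIM])) and performs a single masked atomic add back to global memory after the loop, which should reduce atomic traffic on dq. A minimal sketch of that accumulate-then-write-once pattern in plain Python (NumPy stand-in, not Triton; compute_block_dq, num_kv_blocks, and the other names are hypothetical, not functions from this repository):

import numpy as np

def backward_dq_accumulate(num_kv_blocks, block, head_dim, compute_block_dq, dq_global, row_offset):
    # local accumulator, analogous to dq = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
    dq = np.zeros((block, head_dim), dtype = np.float32)

    for kv_block in range(num_kv_blocks):
        # analogous to the in-loop dq += tl.dot(ds, k) and dq += block_dq above
        dq += compute_block_dq(kv_block)

    # single write-back after the loop, analogous to the one masked tl.atomic_add(dq_ptrs, dq, ...)
    dq_global[row_offset:row_offset + block] += dq

In the actual kernel the final write still goes through tl.atomic_add with the same masking on offs_m and offs_d, since several programs may write to the same dq rows.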

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.49"
+version = "0.0.50"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
