Commit c491409

more prep
1 parent 4ac117a commit c491409

File tree

2 files changed: +68 additions, -19 deletions

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 41 additions & 13 deletions
@@ -111,13 +111,13 @@ def _fwd_kernel(
 
 m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
 
-m_i = tl.zeros([BLOCK], dtype=tl.float32) - float("inf")
+m_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
 
 # lse
 
 lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
 
-lse_i = tl.zeros([BLOCK], dtype=tl.float32) - float("inf")
+lse_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
 
 # output
 
@@ -130,7 +130,7 @@ def _fwd_kernel(
 + (offs_m[:, None] * stride_om + offs_d[None, :])
 )
 
-acc_o = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype=tl.float32)
+acc_o = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
 
 # load queries, keys, values
 
@@ -243,6 +243,8 @@ def flash_attn_forward(
 q,
 k,
 v,
+indices,
+mask,
 block_size = 128
 ):
 q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
@@ -328,15 +330,20 @@ def _bwd_preprocess_do_o_dot(
 off_hb = tl.program_id(1)
 off_b = off_hb // nheads
 off_h = off_hb % nheads
+
 # initialize offsets
+
 offs_m = start_m * BLOCK + tl.arange(0, BLOCK)
 offs_d = tl.arange(0, BLOCK_HEADDIM)
+
 # load
+
 o = tl.load(
 Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
 other=0.0,
 ).to(tl.float32)
+
 do = tl.load(
 DO
 + off_b * stride_dob
@@ -346,8 +353,11 @@ def _bwd_preprocess_do_o_dot(
 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
 other=0.0,
 ).to(tl.float32)
+
 delta = tl.sum(o * do, axis=1)
+
 # write-back
+
 tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
 
 @triton.jit
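Aside: the `_bwd_preprocess_do_o_dot` hunk above stores the usual flash-attention backward precompute. Since o = softmax(q k^T * scale) @ v, the per-row term D_i = sum_j p_ij * dp_ij reduces to a row-sum of o * do, which is why it can be computed from the output and its gradient alone. A minimal eager-PyTorch sketch of the same reduction (shapes assumed to match the `b n h d` layout the autograd wrapper rearranges into; not part of this commit):

import torch

# assumed layout: (batch, seqlen, heads, dim)
o  = torch.randn(1, 512, 2, 64)
do = torch.randn(1, 512, 2, 64)

# same reduction as `delta = tl.sum(o * do, axis=1)` in the kernel, taken per row
delta = (o.float() * do.float()).sum(dim = -1)   # shape (1, 512, 2)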
@@ -538,22 +548,31 @@ def _bwd_kernel_one_col_block(
 # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
 if not (EVEN_M & EVEN_HEADDIM):
 tl.debug_barrier()
+
 dp = tl.dot(do, tl.trans(v))
+
 # There's a race condition for headdim=48
 if not EVEN_HEADDIM:
 tl.debug_barrier()
+
 # compute ds = p * (dp - delta[:, None])
 # Putting the subtraction after the dp matmul (instead of before) is slightly faster
+
 Di = tl.load(D + offs_m)
+
 # Converting ds to q.dtype here reduces register pressure and makes it much faster
 # for BLOCK_HEADDIM=128
+
 ds = (p * (dp - Di[:, None]) * softmax_scale)
 
 ds = ds.to(q.dtype)
 
 # compute dk = dot(ds.T, q)
+
 dk += tl.dot(tl.trans(ds), q)
+
 # compute dq
+
 if not (
 EVEN_M & EVEN_HEADDIM
 ): # Otherewise there's a race condition when BIAS_TYPE='matrix'
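Aside: the `ds` / `dk` lines above are the standard softmax backward written against the precomputed `Di`. A self-contained eager sketch of that algebra for a single (batch, head) slice, purely illustrative (all names are local to the sketch, not the kernel's):

import torch

n, d = 128, 64
softmax_scale = d ** -0.5

q, k, v, do = (torch.randn(n, d) for _ in range(4))

p = (q @ k.t() * softmax_scale).softmax(dim = -1)   # attention probabilities
o = p @ v

dp    = do @ v.t()                                  # plays the role of tl.dot(do, tl.trans(v))
delta = (o * do).sum(dim = -1)                      # the precomputed Di from the preprocess kernel
ds    = p * (dp - delta[:, None]) * softmax_scale   # softmax backward, scaled as in the kernel
dk    = ds.t() @ q                                  # matches dk += tl.dot(tl.trans(ds), q)
dq    = ds @ k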
@@ -613,6 +632,7 @@ def _bwd_kernel_one_col_block(
 # do_ptrs += BLOCK * stride_dom
 
 # write-back
+
 dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
 dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
 _bwd_store_dk_dv(
@@ -756,14 +776,12 @@ def _bwd_kernel(
 
 def flash_attn_backward(
 do,
-q,
-k,
-v,
+q, k, v,
+indices,
+mask,
 o,
 lse,
-dq,
-dk,
-dv,
+dq, dk, dv,
 block_size = 128
 ):
 # Make sure that the last dimension is contiguous
@@ -805,7 +823,7 @@ def flash_attn_backward(
 seqlen_q_rounded,
 dim,
 BLOCK = block_size,
-BLOCK_HEADDIM=BLOCK_HEADDIM,
+BLOCK_HEADDIM = BLOCK_HEADDIM,
 )
 
 grid = lambda META: (
@@ -889,7 +907,12 @@ def forward(
 
 fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))
 
-out, lse = flash_attn_forward(fq, fk, fv, block_size = block_size)
+out, lse = flash_attn_forward(
+fq, fk, fv,
+selected_block_indices,
+fmask,
+block_size = block_size
+)
 
 ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
 ctx._saved_variables = (block_size,)
@@ -901,7 +924,7 @@ def forward(
 def backward(self, ctx, do):
 do = rearrange(do, 'b h n d -> b n h d')
 
-q, k, v, kv_indices, mask, out, lse = ctx.saved_tensors
+q, k, v, sel_block_indices, mask, out, lse = ctx.saved_tensors
 
 (
 block_size,
@@ -912,7 +935,12 @@ def backward(self, ctx, do):
 dk = torch.zeros_like(k)
 dv = torch.zeros_like(v)
 
-flash_attn_backward(do, q, k, v, out, lse, dq, dk, dv, block_size = block_size)
+flash_attn_backward(
+do, q, k, v,
+sel_block_indices, mask,
+out, lse, dq, dk, dv,
+block_size = block_size
+)
 
 dq, dk, dv = tuple(rearrange(t, 'b n h d -> b h n d') for t in (dq, dk, dv))
 return dq, dk, dv, None, None, None, None

test_triton_nsa.py

Lines changed: 27 additions & 6 deletions
@@ -1,14 +1,20 @@
 import torch
 from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend
 
+import einx
 from einops import rearrange, einsum
 
 assert torch.cuda.is_available()
 
 def exists(v):
 return v is not None
 
-def regular_attend(q, k, v, block_size = None):
+def regular_attend(
+q, k, v,
+indices,
+mask,
+block_size = None,
+):
 if exists(block_size):
 w = q.shape[-2] // block_size
 q, k, v = tuple(rearrange(t, 'b h (w n) d -> b (h w) n d', n = block_size) for t in (q, k, v))
@@ -28,19 +34,34 @@ def regular_attend(q, k, v, block_size = None):
 
 return out
 
-q = torch.randn(1, 4, 1024, 64).cuda()
-k = torch.randn(1, 4, 1024, 64).cuda()
-v = torch.randn(1, 4, 1024, 64).cuda()
+# mock inputs
+
+fine_block_size = 64
+
+q = torch.randn(1, 2, 512, 64).cuda()
+k = torch.randn(1, 2, 512, 64).cuda()
+v = torch.randn(1, 2, 512, 64).cuda()
+
+indices = torch.zeros(1, 2, 512, 1).long().cuda()
+mask = torch.zeros(1, 2, 512, 1).bool().cuda()
+
+# both regular and nsa pathways `r` and `n`
 
 rq, rk, rv = tuple(t.clone().requires_grad_() for t in (q, k, v))
 nq, nk, nv = tuple(t.clone().requires_grad_() for t in (q, k, v))
 
-out = regular_attend(rq, rk, rv, block_size = 64)
+# regular forwards and backwards
+
+out = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size)
 out.sum().backward()
 
-nsa_out = native_sparse_attend(nq, nk, nv, 64, None, None, 1)
+# triton nsa forwards and backwards
+
+nsa_out = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, 1)
 nsa_out.sum().backward()
 
+# asserts
+
 assert torch.allclose(out, nsa_out, atol = 1e-2)
 
 assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
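Aside on the mock `indices` and `mask` above: per (batch, head, query position), `indices` lists which fine key/value blocks are selected and `mask` flags whether each selection is active. A minimal eager sketch, assuming the shapes mirror the mock inputs, of gathering the selected key blocks; the names and gather strategy are illustrative, not this repo's implementation:

import torch
from einops import rearrange, repeat

b, h, n, d = 1, 2, 512, 64
block, num_sel = 64, 1

k = torch.randn(b, h, n, d)
indices = torch.zeros(b, h, n, num_sel, dtype = torch.long)   # selected fine block ids per query
mask    = torch.zeros(b, h, n, num_sel, dtype = torch.bool)   # which selections are active

# view keys as blocks of size `block`, then gather each query's selected block(s)
kb  = rearrange(k, 'b h (w bs) d -> b h w bs d', bs = block)
idx = repeat(indices, 'b h i sel -> b h (i sel) bs d', bs = block, d = d)

sel_k = kb.gather(2, idx)
sel_k = rearrange(sel_k, 'b h (i sel) bs d -> b h i (sel bs) d', sel = num_sel)

# sel_k[b, h, i] holds the keys of the block(s) selected for query position i;
# selections with mask == False would simply be excluded from the attention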
