Commit 2f14297

cleanup triton flash attn
1 parent b1dee31 commit 2f14297

File tree

2 files changed

+79 -124 lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 43 additions & 124 deletions
@@ -3,6 +3,8 @@
 # forward is modified to return unnormalized accumulation, row maxes, row lse - reduced over passed rings
 # both forwards and backwards is modified to allow for masking out the diagonal for striped ring attention
 
+from functools import partial
+import math
 from math import ceil
 
 import torch
@@ -82,12 +84,6 @@ def _fwd_kernel(
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
     HAS_BIAS: tl.constexpr,
-    IS_CAUSAL: tl.constexpr,
-    CAUSAL_MASK_DIAGONAL: tl.constexpr,
-    LOAD_ACCUMULATED: tl.constexpr,
-    RETURN_NORMALIZED_OUTPUT: tl.constexpr,
-    SOFTCLAMP_QK_SIM: tl.constexpr,
-    SOFTCLAMP_VALUE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -121,19 +117,13 @@ def _fwd_kernel(
 
     m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
 
-    if LOAD_ACCUMULATED:
-        m_i = tl.load(m_ptrs)
-    else:
-        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
 
     # load lse
 
     lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
 
-    if LOAD_ACCUMULATED:
-        lse_i = tl.load(lse_ptrs)
-    else:
-        lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
 
     # load accumualted output
 
@@ -146,23 +136,7 @@
         + (offs_m[:, None] * stride_om + offs_d[None, :])
     )
 
-    if LOAD_ACCUMULATED:
-        if EVEN_M:
-            if EVEN_HEADDIM:
-                acc_o = tl.load(out_ptrs)
-            else:
-                acc_o = tl.load(out_ptrs, mask=offs_d[None, :] < headdim)
-        else:
-            if EVEN_HEADDIM:
-                acc_o = tl.load(out_ptrs, mask=offs_m[:, None] < seqlen_q)
-            else:
-                acc_o = tl.load(
-                    out_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
-                )
-
-        acc_o = acc_o.to(tl.float32)
-    else:
-        acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
+    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
 
     # load queries, keys, values
 
@@ -179,7 +153,7 @@ def _fwd_kernel(
                 q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
             )
 
-    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
+    end_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
     for start_n in range(0, end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
 
@@ -204,21 +178,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, tl.trans(k))
 
-        if SOFTCLAMP_QK_SIM:
-            effective_softclamp_value = SOFTCLAMP_VALUE / softmax_scale
-            qk /= effective_softclamp_value
-            qk = libdevice.tanh(qk)
-            qk *= effective_softclamp_value
-
         if not EVEN_N:
             qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
 
-        if IS_CAUSAL:
-            if CAUSAL_MASK_DIAGONAL:
-                # needed for stripe attention
-                qk += tl.where(offs_m[:, None] > (start_n + offs_n)[None, :], 0, float("-inf"))
-            else:
-                qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
+        qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
 
         if HAS_BIAS:
             if EVEN_N:
@@ -270,9 +233,8 @@ def _fwd_kernel(
         l_i_new = tl.exp(lse_i - m_ij) + l_ij
         lse_i = m_ij + tl.log(l_i_new)
 
-    if RETURN_NORMALIZED_OUTPUT:
-        acc_o_scale = tl.exp(m_i - lse_i)
-        acc_o = acc_o * acc_o_scale[:, None]
+    acc_o_scale = tl.exp(m_i - lse_i)
+    acc_o = acc_o * acc_o_scale[:, None]
 
     # offsets for m and lse
 
@@ -283,9 +245,6 @@ def _fwd_kernel(
 
     tl.store(lse_ptrs, lse_i)
 
-    if not RETURN_NORMALIZED_OUTPUT:
-        tl.store(m_ptrs, m_i)
-
     # write to output
 
     if EVEN_M:
@@ -306,27 +265,14 @@ def flash_attn_forward(
     k,
     v,
     bias = None,
-    causal = False,
     o = None,
     m = None,
     lse = None,
     softmax_scale = None,
-    causal_mask_diagonal = False,
-    return_normalized_output = False,
-    load_accumulated = True,
-    softclamp_qk_sim = False,
-    softclamp_value = 50.,
-    head_first_dim = False,
    remove_padding = False
 ):
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
 
-    if head_first_dim:
-        q, k, v = tuple(rearrange(t, 'b h n d -> b n h d') for t in (q, k, v))
-
-        if exists(o):
-            o = rearrange(o, 'b h n d -> b n h d')
-
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
 
@@ -360,17 +306,14 @@ def flash_attn_forward(
 
     if not exists(lse):
         max_neg_value = -torch.finfo(torch.float32).max
-        init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty
-        lse = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+        lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
 
     if not exists(m):
         max_neg_value = -torch.finfo(torch.float32).max
-        init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty
-        m = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+        m = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
 
     if not exists(o):
-        init_fn = torch.zeros_like if load_accumulated else torch.empty_like
-        o = init_fn(q)
+        o = torch.empty_like(q)
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
     BLOCK = 128
@@ -407,27 +350,17 @@ def flash_attn_forward(
         seqlen_q // 32,
         seqlen_k // 32,
         has_bias,
-        causal,
-        causal_mask_diagonal,
-        load_accumulated,
-        return_normalized_output,
-        softclamp_qk_sim,
-        softclamp_value,
         BLOCK_HEADDIM,
         BLOCK_M = BLOCK,
         BLOCK_N = BLOCK,
         num_warps = num_warps,
         num_stages = 1,
     )
 
-    if head_first_dim:
-        o = rearrange(o, 'b n h d -> b h n d')
-
     if remove_padding:
-        m = m[..., :seqlen_q]
         lse = lse[..., :seqlen_q]
 
-    return o, m, lse
+    return o, lse
 
 @triton.jit
 def _bwd_preprocess_do_o_dot(
@@ -533,10 +466,6 @@ def _bwd_kernel_one_col_block(
     headdim,
     ATOMIC_ADD: tl.constexpr,
     BIAS_TYPE: tl.constexpr,
-    IS_CAUSAL: tl.constexpr,
-    CAUSAL_MASK_DIAGONAL: tl.constexpr,
-    SOFTCLAMP_QK_SIM: tl.constexpr,
-    SOFTCLAMP_VALUE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -545,7 +474,7 @@ def _bwd_kernel_one_col_block(
     BLOCK_N: tl.constexpr,
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
-    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
+    begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
     # initialize row/col offsets
     offs_qm = begin_m + tl.arange(0, BLOCK_M)
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
@@ -627,22 +556,11 @@ def _bwd_kernel_one_col_block(
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, tl.trans(k))
 
-        if SOFTCLAMP_QK_SIM:
-            effective_softclamp_value = SOFTCLAMP_VALUE / softmax_scale
-            qk /= effective_softclamp_value
-            qk = libdevice.tanh(qk)
-            dtanh = 1. - qk * qk
-            qk *= effective_softclamp_value
-
         # Trying to combine the two masks seem to make the result wrong
         if not EVEN_N: # Need to mask out otherwise the softmax is wrong
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
-        if IS_CAUSAL:
-            if CAUSAL_MASK_DIAGONAL:
-                # needed for stripe attention
-                qk = tl.where(offs_m_curr[:, None] > (offs_n[None, :]), qk, float("-inf"))
-            else:
-                qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
+
+        qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
 
         if BIAS_TYPE != "none":
             tl.debug_barrier() # Race condition otherwise
@@ -714,9 +632,6 @@ def _bwd_kernel_one_col_block(
         # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale)
 
-        if SOFTCLAMP_QK_SIM:
-            ds *= dtanh
-
         ds = ds.to(q.dtype)
 
         # compute dk = dot(ds.T, q)
@@ -823,7 +738,7 @@ def init_to_zero(name):
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
     ],
-    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
+    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "BLOCK_HEADDIM"],
 )
 @triton.heuristics(
     {
@@ -877,10 +792,6 @@ def _bwd_kernel(
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
     BIAS_TYPE: tl.constexpr,
-    IS_CAUSAL: tl.constexpr,
-    CAUSAL_MASK_DIAGONAL: tl.constexpr,
-    SOFTCLAMP_QK_SIM: tl.constexpr,
-    SOFTCLAMP_VALUE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
     EVEN_M: tl.constexpr,
@@ -934,10 +845,6 @@ def _bwd_kernel(
                 headdim,
                 ATOMIC_ADD=False,
                 BIAS_TYPE=BIAS_TYPE,
-                IS_CAUSAL=IS_CAUSAL,
-                CAUSAL_MASK_DIAGONAL = CAUSAL_MASK_DIAGONAL,
-                SOFTCLAMP_QK_SIM = SOFTCLAMP_QK_SIM,
-                SOFTCLAMP_VALUE = SOFTCLAMP_VALUE,
                 BLOCK_HEADDIM=BLOCK_HEADDIM,
                 EVEN_M=EVEN_M,
                 EVEN_N=EVEN_N,
@@ -973,10 +880,6 @@ def _bwd_kernel(
             headdim,
             ATOMIC_ADD=True,
             BIAS_TYPE=BIAS_TYPE,
-            IS_CAUSAL=IS_CAUSAL,
-            CAUSAL_MASK_DIAGONAL = CAUSAL_MASK_DIAGONAL,
-            SOFTCLAMP_QK_SIM = SOFTCLAMP_QK_SIM,
-            SOFTCLAMP_VALUE = SOFTCLAMP_VALUE,
             BLOCK_HEADDIM=BLOCK_HEADDIM,
             EVEN_M=EVEN_M,
             EVEN_N=EVEN_N,
@@ -997,15 +900,12 @@ def flash_attn_backward(
     dv,
     delta = None,
     bias = None,
-    causal = False,
-    causal_mask_diagonal = False,
     softmax_scale = None,
-    softclamp_qk_sim = False,
-    softclamp_value = 50.
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
         do = do.contiguous()
+
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
     # assert d in {16, 32, 64, 128}
@@ -1113,10 +1013,6 @@ def flash_attn_backward(
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
         bias_type,
-        causal,
-        causal_mask_diagonal,
-        softclamp_qk_sim,
-        softclamp_value,
         BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
@@ -1142,10 +1038,33 @@ def forward(
         selected_block_indices,
         num_grouped_queries
     ):
-        raise NotImplementedError
+        fq, fk, fv = tuple(rearrange(t, 'b h n d -> b n h d') for t in (fq, fk, fv))
+
+        dtype = fq.dtype
+
+        fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))
+
+        out, lse = flash_attn_forward(fq, fk, fv)
+
+        ctx.save_for_backward(fq, fk, fv, out, lse)
+
+        out = rearrange(out, 'b n h d -> b h n d')
+        return out.type(dtype)
 
     @classmethod
     def backward(self, ctx, do):
-        raise NotImplementedError
+        do = rearrange(do, 'b h n d -> b n h d')
+
+        q, k, v, out, lse = ctx.saved_tensors
+
+        do = do.half()
+        dq = torch.zeros_like(q)
+        dk = torch.zeros_like(k)
+        dv = torch.zeros_like(v)
+
+        flash_attn_backward(do, q, k, v, out, lse, dq, dk, dv)
+
+        dq, dk, dv = tuple(rearrange(t, 'b n h d -> b h n d') for t in (dq, dk, dv))
+        return dq, dk, dv, None, None, None
 
 native_sparse_attend = NSA.apply
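With IS_CAUSAL, CAUSAL_MASK_DIAGONAL, LOAD_ACCUMULATED, RETURN_NORMALIZED_OUTPUT and the softclamp flags removed, flash_attn_forward now always computes plain causal attention, always normalizes the accumulated output, and returns (o, lse) instead of (o, m, lse). Below is a minimal PyTorch sketch of that contract, not part of the commit: the name reference_causal_attention is made up, the (batch, seq, heads, dim) layout matches what flash_attn_forward expects, and the seqlen_q_rounded padding of lse and the fp16 cast in NSA.forward are ignored.

import torch

def reference_causal_attention(q, k, v, softmax_scale = None):
    # q, k, v: (batch, seqlen, heads, dim) - the layout flash_attn_forward expects
    scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5

    # scores with heads next to batch, like the kernel's per-(batch, head) program
    sim = torch.einsum('b i h d, b j h d -> b h i j', q, k) * scale

    # hard-coded causal mask, matching `offs_m[:, None] >= (start_n + offs_n)[None, :]`
    i, j = sim.shape[-2:]
    causal_mask = torch.ones((i, j), device = sim.device, dtype = torch.bool).triu(1)
    sim = sim.masked_fill(causal_mask, float('-inf'))

    # the cleaned-up kernel always returns the normalized output plus the row logsumexp
    lse = sim.logsumexp(dim = -1)
    out = torch.einsum('b h i j, b j h d -> b i h d', torch.exp(sim - lse[..., None]), v)
    return out, lse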

test_triton_nsa.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import torch
+from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend
+
+from einops import rearrange, einsum
+
+assert torch.cuda.is_available()
+
+def regular_attend(q, k, v):
+    seq_len, device = q.shape[-2], q.device
+    scale = q.shape[-1] ** -0.5
+
+    sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * scale
+    causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)
+    sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
+    attn = sim.softmax(dim = -1)
+
+    return einsum(attn, v, 'b h i j, b h j d -> b h i d')
+
+q = torch.randn(1, 1024, 4, 64).cuda()
+k = torch.randn(1, 1024, 4, 64).cuda()
+v = torch.randn(1, 1024, 4, 64).cuda()
+
+rq, rk, rv = tuple(t.clone().requires_grad_() for t in (q, k, v))
+nq, nk, nv = tuple(t.clone().requires_grad_() for t in (q, k, v))
+
+out = regular_attend(rq, rk, rv)
+out.sum().backward()
+
+nsa_out = native_sparse_attend(nq, nk, nv, 4, None, 1)
+nsa_out.sum().backward()
+
+assert torch.allclose(out, nsa_out, atol = 1e-2)
+
+assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
+assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
+assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
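For callers, native_sparse_attend keeps the head-first (batch, heads, seq, dim) layout used by regular_attend above; NSA.forward does the 'b h n d -> b n h d' rearrange and the cast to half internally, then restores the layout and dtype. A small usage sketch, not part of the commit: the shapes are arbitrary, and the three trailing arguments simply mirror the test (the diff only names the last two, selected_block_indices and num_grouped_queries).

import torch
from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

# head-first layout; the wrapper handles layout and precision for the Triton kernel
q = torch.randn(1, 8, 512, 64, device = 'cuda', requires_grad = True)
k = torch.randn(1, 8, 512, 64, device = 'cuda', requires_grad = True)
v = torch.randn(1, 8, 512, 64, device = 'cuda', requires_grad = True)

out = native_sparse_attend(q, k, v, 4, None, 1)  # trailing args mirror the test above
out.sum().backward()                             # runs flash_attn_backward via NSA.backward

assert q.grad is not None and q.grad.shape == q.shape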
