
Commit 7a944f1

keep cleaning
1 parent ea4fb34 commit 7a944f1

File tree

1 file changed: +15 -26 lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 15 additions & 26 deletions
@@ -240,23 +240,21 @@ def flash_attn_forward(
     q,
     k,
     v,
-    softmax_scale = None,
-    remove_padding = False,
     block_size = 128
 ):
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
 
-    batch, seqlen_q, nheads, d = q.shape
+    batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
 
-    assert k.shape == (batch, seqlen_k, nheads, d)
-    assert v.shape == (batch, seqlen_k, nheads, d)
-    assert d <= 128, "FlashAttention only support head dimensions up to 128"
+    assert k.shape == (batch, seqlen_k, nheads, dim)
+    assert v.shape == (batch, seqlen_k, nheads, dim)
+    assert dim <= 128, "only support head dimensions up to 128"
     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
     assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
     assert q.is_cuda and k.is_cuda and v.is_cuda
 
-    softmax_scale = default(softmax_scale, d ** -0.5)
+    softmax_scale = dim ** -0.5
 
     seqlen_q_rounded = ceil(seqlen_q / 128) * 128
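Note on the hunk above: with softmax_scale no longer an argument, the forward pass always scales by dim ** -0.5, which is the same value the removed default() fallback produced when no scale was given. A minimal sketch of that equivalence (values here are illustrative, not from the file):

    import math

    dim = 64                            # example head dimension
    scale_new = dim ** -0.5             # what flash_attn_forward now always uses
    scale_old = 1.0 / math.sqrt(dim)    # the old fallback when softmax_scale was None
    assert math.isclose(scale_new, scale_old)  # both equal 0.125 for dim = 64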

@@ -266,8 +264,8 @@ def flash_attn_forward(
 
     o = torch.empty_like(q)
 
-    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
-    num_warps = 4 if d <= 64 else 8
+    BLOCK_HEADDIM = max(triton.next_power_of_2(dim), 16)
+    num_warps = 4 if dim <= 64 else 8
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)
 
     _fwd_kernel[grid](
@@ -294,7 +292,7 @@ def flash_attn_forward(
         seqlen_q,
         seqlen_k,
         seqlen_q_rounded,
-        d,
+        dim,
         seqlen_q // 32,
         seqlen_k // 32,
         BLOCK_HEADDIM,
@@ -303,9 +301,6 @@ def flash_attn_forward(
         num_stages = 1,
     )
 
-    if remove_padding:
-        lse = lse[..., :seqlen_q]
-
     return o, lse
 
 @triton.jit
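With remove_padding gone, lse is always returned at the rounded length seqlen_q_rounded (a multiple of 128), and trimming is left to the caller. A hedged caller-side sketch, assuming flash_attn_forward can be invoked directly on dense (batch, seqlen, heads, dim) tensors as its asserts suggest (the toy tensors are illustrative, not from this repo):

    import torch
    from native_sparse_attention_pytorch.triton_native_sparse_attention import flash_attn_forward

    # (batch, seqlen, heads, head_dim) fp16 CUDA tensors; seqlen kept a multiple of 128 to stay conservative
    q, k, v = (torch.randn(1, 1024, 8, 64, dtype = torch.float16, device = 'cuda') for _ in range(3))

    o, lse = flash_attn_forward(q, k, v, block_size = 128)
    lse = lse[..., :q.shape[1]]  # the same slice the removed remove_padding branch used to apply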
@@ -352,7 +347,6 @@ def _bwd_preprocess_do_o_dot(
     # write-back
     tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
 
-
 @triton.jit
 def _bwd_store_dk_dv(
     dk_ptrs,
@@ -632,10 +626,6 @@ def _bwd_kernel_one_col_block(
         EVEN_HEADDIM=EVEN_HEADDIM,
     )
 
-
-def init_to_zero(name):
-    return lambda nargs: nargs[name].zero_()
-
 @triton.jit
 def _bwd_kernel(
     Q,
@@ -771,29 +761,28 @@ def flash_attn_backward(
     dq,
     dk,
     dv,
-    softmax_scale = None,
     block_size = 128
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
         do = do.contiguous()
 
-    batch, seqlen_q, nheads, d = q.shape
+    batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
     # assert d in {16, 32, 64, 128}
-    assert d <= 128
+    assert dim <= 128
     seqlen_q_rounded = ceil(seqlen_q / 128) * 128
 
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
     assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
     assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
-    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
+    softmax_scale = dim ** -0.5
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
     dq_accum = torch.empty_like(q, dtype=torch.float32)
 
     # delta = torch.zeros_like(lse)
 
-    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
+    BLOCK_HEADDIM = max(triton.next_power_of_2(dim), 16)
 
     delta = torch.empty_like(lse)
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)
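Side note on the removed backward default above: `softmax_scale or 1.0 / math.sqrt(d)` falls back on truthiness rather than an explicit None check, so an explicitly passed scale of 0.0 was silently replaced by the default; the hard-coded dim ** -0.5 avoids that corner case. A tiny illustration (not code from this repo):

    import math

    def old_default(softmax_scale, d):
        # the pattern the commit removes
        return softmax_scale or 1.0 / math.sqrt(d)

    print(old_default(None, 64))  # 0.125, as intended
    print(old_default(0.0, 64))   # also 0.125: an explicit 0.0 is overridden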
@@ -810,7 +799,7 @@ def flash_attn_backward(
         nheads,
         seqlen_q,
         seqlen_q_rounded,
-        d,
+        dim,
         BLOCK = block_size,
         BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
@@ -858,7 +847,7 @@ def flash_attn_backward(
         seqlen_q,
         seqlen_k,
         seqlen_q_rounded,
-        d,
+        dim,
         seqlen_q // 32,
         seqlen_k // 32, # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
@@ -868,7 +857,7 @@ def flash_attn_backward(
         SEQUENCE_PARALLEL = False,
         EVEN_M = (seqlen_q % block_size) == 0,
         EVEN_N = (seqlen_k % block_size) == 0,
-        EVEN_HEADDIM = BLOCK_HEADDIM == d
+        EVEN_HEADDIM = BLOCK_HEADDIM == dim
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
         # num_stages=1,
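Since both flash_attn_forward and flash_attn_backward now derive the scale from the head dimension alone, a caller who previously passed a custom softmax_scale can recover the same logits by folding it into the queries: scaling q by custom_scale / dim ** -0.5 before the call makes the kernel's fixed dim ** -0.5 multiply out to custom_scale. A hedged sketch of that workaround (the custom value and tensors are assumptions for illustration, not from this commit):

    import torch
    from native_sparse_attention_pytorch.triton_native_sparse_attention import flash_attn_forward

    q, k, v = (torch.randn(1, 1024, 8, 64, dtype = torch.float16, device = 'cuda') for _ in range(3))

    custom_scale = 0.1                    # value one would previously have passed as softmax_scale
    default_scale = q.shape[-1] ** -0.5   # what the kernels now apply unconditionally
    q_scaled = q * (custom_scale / default_scale)  # fold the custom scale into the queries

    o, lse = flash_attn_forward(q_scaled, k, v, block_size = 128)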
