
Commit 19d0e66

more driveby cleaning
1 parent 7a944f1 commit 19d0e66

1 file changed (+13 −5)

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 13 additions & 5 deletions
@@ -18,6 +18,9 @@ def exists(v):
 def default(val, d):
     return val if exists(val) else d
 
+def round_up_multiple(n, mult):
+    return ceil(n / mult) * mult
+
 def is_contiguous(x: Tensor):
     return x.stride(-1) == 1
 
@@ -42,6 +45,10 @@ def is_contiguous(x: Tensor):
 import triton.language as tl
 from triton.language.extra import libdevice
 
+# constants
+
+TRITON_BLOCK_SIZE = 128
+
 # kernels
 
 @triton.heuristics(
@@ -256,7 +263,7 @@ def flash_attn_forward(
 
     softmax_scale = dim ** -0.5
 
-    seqlen_q_rounded = ceil(seqlen_q / 128) * 128
+    seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
 
     lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
 
@@ -764,18 +771,18 @@ def flash_attn_backward(
     block_size = 128
 ):
     # Make sure that the last dimension is contiguous
-    if do.stride(-1) != 1:
+    if not is_contiguous(do):
         do = do.contiguous()
 
     batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
     # assert d in {16, 32, 64, 128}
     assert dim <= 128
-    seqlen_q_rounded = ceil(seqlen_q / 128) * 128
+    seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
 
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
-    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
-    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
+    assert all([is_contiguous(t) for t in (q, k, v, o, dq, dk, dv)])
+
     softmax_scale = dim ** -0.5
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
     dq_accum = torch.empty_like(q, dtype=torch.float32)
@@ -786,6 +793,7 @@ def flash_attn_backward(
 
     delta = torch.empty_like(lse)
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)
+
    _bwd_preprocess_do_o_dot[grid](
         o,
         do,
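For context, a self-contained sketch of the helpers this commit factors out and how they are used. `round_up_multiple`, `is_contiguous`, and `TRITON_BLOCK_SIZE` mirror the diff; the sequence length, tensor shapes, and the demo under `__main__` are illustrative assumptions, not part of the library.

    # minimal sketch of the refactored helpers; standalone illustration only
    from math import ceil

    import torch
    from torch import Tensor

    # block size the Triton kernels round the query sequence length up to
    TRITON_BLOCK_SIZE = 128

    def round_up_multiple(n, mult):
        # round n up to the nearest multiple of mult, e.g. 1000 -> 1024 for mult = 128
        return ceil(n / mult) * mult

    def is_contiguous(x: Tensor):
        # the kernels only require the last (feature) dimension to be contiguous
        return x.stride(-1) == 1

    if __name__ == '__main__':
        seqlen_q = 1000
        seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
        assert seqlen_q_rounded == 1024

        q = torch.randn(2, seqlen_q, 8, 64)
        assert is_contiguous(q)

        # a transposed view is no longer contiguous in its last dimension
        assert not is_contiguous(q.transpose(1, 3))

Centralizing the rounding in one helper keyed to a named constant, and reusing `is_contiguous` in place of raw stride checks, keeps the forward and backward paths consistent if the block size ever changes.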
