@@ -18,6 +18,9 @@ def exists(v):
 def default(val, d):
     return val if exists(val) else d
 
+def round_up_multiple(n, mult):
+    return ceil(n / mult) * mult
+
 def is_contiguous(x: Tensor):
     return x.stride(-1) == 1
 
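Not part of the diff itself, just a note on the new helper: it is plain ceiling arithmetic that pads a length up to the next block boundary. A minimal standalone sketch of the behavior (the sample lengths below are illustrative, not taken from the commit):

    from math import ceil

    def round_up_multiple(n, mult):
        # smallest multiple of `mult` that is >= n
        return ceil(n / mult) * mult

    assert round_up_multiple(1000, 128) == 1024  # a 1000-token sequence needs 8 blocks of 128
    assert round_up_multiple(1024, 128) == 1024  # exact multiples are left unchanged
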
@@ -42,6 +45,10 @@ def is_contiguous(x: Tensor):
 import triton.language as tl
 from triton.language.extra import libdevice
 
+# constants
+
+TRITON_BLOCK_SIZE = 128
+
 # kernels
 
 @triton.heuristics(
@@ -256,7 +263,7 @@ def flash_attn_forward(
 
     softmax_scale = dim ** -0.5
 
-    seqlen_q_rounded = ceil(seqlen_q / 128) * 128
+    seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
 
     lse = torch.empty((batch, nheads, seqlen_q_rounded), device = q.device, dtype = torch.float32)
 
@@ -764,18 +771,18 @@ def flash_attn_backward(
     block_size = 128
 ):
     # Make sure that the last dimension is contiguous
-    if do.stride(-1) != 1:
+    if not is_contiguous(do):
         do = do.contiguous()
 
     batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
     # assert d in {16, 32, 64, 128}
     assert dim <= 128
-    seqlen_q_rounded = ceil(seqlen_q / 128) * 128
+    seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
 
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
-    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
-    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
+    assert all([is_contiguous(t) for t in (q, k, v, o, dq, dk, dv)])
+
     softmax_scale = dim ** -0.5
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
     dq_accum = torch.empty_like(q, dtype = torch.float32)
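For context on the assert swap above: is_contiguous (defined near the top of the file) only checks that the last-dimension stride is 1, so the single all(...) assert is equivalent to the chained stride comparisons it replaces. A quick illustrative check, assuming the same (batch, seqlen, heads, dim) layout the backward pass expects:

    import torch

    x = torch.randn(2, 8, 4, 64)
    assert x.stride(-1) == 1                 # last dim contiguous, as the kernels require
    assert x.transpose(1, 3).stride(-1) != 1 # transposing breaks last-dim contiguity
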
@@ -786,6 +793,7 @@ def flash_attn_backward(
 
     delta = torch.empty_like(lse)
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)
+
     _bwd_preprocess_do_o_dot[grid](
         o,
         do,
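As a side note on the launch line shown in this last hunk: the grid is a callable that receives the kernel's meta-parameters, so the number of tiles along the sequence axis adapts to whatever BLOCK the heuristics pick, and triton.cdiv is just ceiling division. A small sketch with made-up sizes (seqlen 1000, batch 2, 4 heads):

    import triton

    # ceiling division: number of BLOCK-sized tiles needed to cover the sequence
    assert triton.cdiv(1000, 128) == 8

    grid = lambda META: (triton.cdiv(1000, META["BLOCK"]), 2 * 4)
    assert grid({"BLOCK": 128}) == (8, 8)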