@@ -240,23 +240,21 @@ def flash_attn_forward(
     q,
     k,
     v,
-    softmax_scale=None,
-    remove_padding=False,
     block_size=128
 ):
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]

-    batch, seqlen_q, nheads, d = q.shape
+    batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape

-    assert k.shape == (batch, seqlen_k, nheads, d)
-    assert v.shape == (batch, seqlen_k, nheads, d)
-    assert d <= 128, "FlashAttention only support head dimensions up to 128"
+    assert k.shape == (batch, seqlen_k, nheads, dim)
+    assert v.shape == (batch, seqlen_k, nheads, dim)
+    assert dim <= 128, "only support head dimensions up to 128"
     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
     assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
     assert q.is_cuda and k.is_cuda and v.is_cuda

-    softmax_scale = default(softmax_scale, d ** -0.5)
+    softmax_scale = dim ** -0.5

     seqlen_q_rounded = ceil(seqlen_q / 128) * 128

@@ -266,8 +264,8 @@ def flash_attn_forward(

     o = torch.empty_like(q)

-    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
-    num_warps = 4 if d <= 64 else 8
+    BLOCK_HEADDIM = max(triton.next_power_of_2(dim), 16)
+    num_warps = 4 if dim <= 64 else 8
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)

     _fwd_kernel[grid](
@@ -294,7 +292,7 @@ def flash_attn_forward(
         seqlen_q,
         seqlen_k,
         seqlen_q_rounded,
-        d,
+        dim,
         seqlen_q // 32,
         seqlen_k // 32,
         BLOCK_HEADDIM,
@@ -303,9 +301,6 @@ def flash_attn_forward(
         num_stages=1,
     )

-    if remove_padding:
-        lse = lse[..., :seqlen_q]
-
     return o, lse

 @triton.jit
@@ -352,7 +347,6 @@ def _bwd_preprocess_do_o_dot(
     # write-back
     tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)

-
 @triton.jit
 def _bwd_store_dk_dv(
     dk_ptrs,
@@ -632,10 +626,6 @@ def _bwd_kernel_one_col_block(
         EVEN_HEADDIM=EVEN_HEADDIM,
     )

-
-def init_to_zero(name):
-    return lambda nargs: nargs[name].zero_()
-
 @triton.jit
 def _bwd_kernel(
     Q,
@@ -771,29 +761,28 @@ def flash_attn_backward(
     dq,
     dk,
     dv,
-    softmax_scale=None,
     block_size=128
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
         do = do.contiguous()

-    batch, seqlen_q, nheads, d = q.shape
+    batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
     # assert d in {16, 32, 64, 128}
-    assert d <= 128
+    assert dim <= 128
     seqlen_q_rounded = ceil(seqlen_q / 128) * 128

     assert lse.shape == (batch, nheads, seqlen_q_rounded)
     assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
     assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
-    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
+    softmax_scale = dim ** -0.5
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
     dq_accum = torch.empty_like(q, dtype=torch.float32)

     # delta = torch.zeros_like(lse)

-    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
+    BLOCK_HEADDIM = max(triton.next_power_of_2(dim), 16)

     delta = torch.empty_like(lse)
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)
@@ -810,7 +799,7 @@ def flash_attn_backward(
         nheads,
         seqlen_q,
         seqlen_q_rounded,
-        d,
+        dim,
         BLOCK=block_size,
         BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
@@ -858,7 +847,7 @@ def flash_attn_backward(
         seqlen_q,
         seqlen_k,
         seqlen_q_rounded,
-        d,
+        dim,
         seqlen_q // 32,
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
@@ -868,7 +857,7 @@ def flash_attn_backward(
         SEQUENCE_PARALLEL=False,
         EVEN_M=(seqlen_q % block_size) == 0,
         EVEN_N=(seqlen_k % block_size) == 0,
-        EVEN_HEADDIM=BLOCK_HEADDIM == d
+        EVEN_HEADDIM=BLOCK_HEADDIM == dim
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
         # num_stages=1,
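
A minimal usage sketch of the forward entry point after this change (not part of the commit; tensor sizes and device are illustrative assumptions, and flash_attn_forward is assumed importable from the module above). The softmax scale is now always dim ** -0.5 internally, and since remove_padding is gone, lse comes back padded to a multiple of 128 along the query length:

    import torch

    # hypothetical inputs: (batch, seqlen, nheads, dim), fp16, on GPU, dim <= 128
    q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")

    # o has the same shape as q; lse has shape (batch, nheads, ceil(seqlen_q / 128) * 128)
    o, lse = flash_attn_forward(q, k, v, block_size=128)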