@@ -240,9 +240,6 @@ def flash_attn_forward(
240240 q ,
241241 k ,
242242 v ,
243- o = None ,
244- m = None ,
245- lse = None ,
246243 softmax_scale = None ,
247244 remove_padding = False ,
248245 block_size = 128
@@ -263,16 +260,11 @@ def flash_attn_forward(
263260
264261 seqlen_q_rounded = ceil (seqlen_q / 128 ) * 128
265262
266- if not exists (lse ):
267- max_neg_value = - torch .finfo (torch .float32 ).max
268- lse = torch .empty ((batch , nheads , seqlen_q_rounded ), device = q .device , dtype = torch .float32 )
263+ lse = torch .empty ((batch , nheads , seqlen_q_rounded ), device = q .device , dtype = torch .float32 )
269264
270- if not exists (m ):
271- max_neg_value = - torch .finfo (torch .float32 ).max
272- m = torch .empty ((batch , nheads , seqlen_q_rounded ), device = q .device , dtype = torch .float32 )
265+ m = torch .empty ((batch , nheads , seqlen_q_rounded ), device = q .device , dtype = torch .float32 )
273266
274- if not exists (o ):
275- o = torch .empty_like (q )
267+ o = torch .empty_like (q )
276268
277269 BLOCK_HEADDIM = max (triton .next_power_of_2 (d ), 16 )
278270 num_warps = 4 if d <= 64 else 8
@@ -779,7 +771,6 @@ def flash_attn_backward(
779771 dq ,
780772 dk ,
781773 dv ,
782- delta = None ,
783774 softmax_scale = None ,
784775 block_size = 128
785776):
@@ -804,26 +795,25 @@ def flash_attn_backward(
804795
805796 BLOCK_HEADDIM = max (triton .next_power_of_2 (d ), 16 )
806797
807- if not exists (delta ):
808- delta = torch .empty_like (lse )
809- grid = lambda META : (triton .cdiv (seqlen_q , META ["BLOCK" ]), batch * nheads )
810- _bwd_preprocess_do_o_dot [grid ](
811- o ,
812- do ,
813- delta ,
814- o .stride (0 ),
815- o .stride (2 ),
816- o .stride (1 ),
817- do .stride (0 ),
818- do .stride (2 ),
819- do .stride (1 ),
820- nheads ,
821- seqlen_q ,
822- seqlen_q_rounded ,
823- d ,
824- BLOCK = block_size ,
825- BLOCK_HEADDIM = BLOCK_HEADDIM ,
826- )
798+ delta = torch .empty_like (lse )
799+ grid = lambda META : (triton .cdiv (seqlen_q , META ["BLOCK" ]), batch * nheads )
800+ _bwd_preprocess_do_o_dot [grid ](
801+ o ,
802+ do ,
803+ delta ,
804+ o .stride (0 ),
805+ o .stride (2 ),
806+ o .stride (1 ),
807+ do .stride (0 ),
808+ do .stride (2 ),
809+ do .stride (1 ),
810+ nheads ,
811+ seqlen_q ,
812+ seqlen_q_rounded ,
813+ d ,
814+ BLOCK = block_size ,
815+ BLOCK_HEADDIM = BLOCK_HEADDIM ,
816+ )
827817
828818 # BLOCK_M = 128
829819 # BLOCK_N = 64
0 commit comments