@@ -65,7 +65,6 @@ def forward_kernel(
     kv_block_indices,
     kv_block_mask,
     Out,
-    M,
     Lse,
     softmax_scale,
     stride_qb,
@@ -118,8 +117,6 @@ def forward_kernel(
 
     # maximum
 
-    m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
-
     m_i = tl.zeros([BLOCK], dtype=tl.float32) - float("inf")
 
     # lse
@@ -189,7 +186,7 @@ def forward_kernel(
         l_ij = tl.sum(p, 1)
 
         acc_o_scale = tl.exp(m_i - m_ij)
-        acc_o = acc_o * acc_o_scale[:, None]
+        acc_o *= acc_o_scale[:, None]
 
         if EVEN_N & EVEN_M:
             if EVEN_HEADDIM:
@@ -302,12 +299,7 @@ def forward_kernel(
     # normalize accumulated out
 
     acc_o_scale = tl.exp(m_i - lse_i)
-    acc_o = acc_o * acc_o_scale[:, None]
-
-    # offsets
-
-    start_m = tl.program_id(0)
-    offs_m = start_m * BLOCK + tl.arange(0, BLOCK)
+    acc_o *= acc_o_scale[:, None]
 
     # write back lse
 
@@ -357,8 +349,6 @@ def flash_attn_forward(
 
     lse = torch.empty((batch, nheads, seqlen_q_rounded), device=device, dtype=torch.float32)
 
-    m = torch.empty((batch, nheads, seqlen_q_rounded), device=device, dtype=torch.float32)
-
     o = torch.empty_like(q)
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(dim), 16)
@@ -372,7 +362,6 @@ def flash_attn_forward(
         kv_block_indices,
         kv_block_mask,
         o,
-        m,
         lse,
         softmax_scale,
         q.stride(0),
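The acc_o rescaling hunks above are the standard online-softmax recurrence, and the final exp(m_i - lse_i) normalization is why the per-row max buffer M removed by this diff never needs to be written back: lse_i alone reconstructs the normalizer. Below is a minimal self-contained NumPy sketch of that recurrence; the block size, function name, and dense loop are illustrative assumptions, not the kernel's actual tiling, masking, or scaling.

import numpy as np

def online_softmax_attention(q, k, v, block=2):
    """Row-wise softmax(q @ k.T) @ v computed one K/V block at a time.

    Mirrors the kernel's state: a running row max m_i, a running
    log-sum-exp lse_i, and an unnormalized accumulator acc_o that is
    rescaled by exp(m_i - m_ij) whenever the running max increases.
    """
    n_q, d = q.shape
    m_i = np.full(n_q, -np.inf)        # running row max
    lse_i = np.full(n_q, -np.inf)      # running log-sum-exp
    acc_o = np.zeros((n_q, d))         # unnormalized output accumulator

    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        qk = q @ kb.T                            # scores for this block
        m_ij = np.maximum(m_i, qk.max(axis=1))   # new running max
        p = np.exp(qk - m_ij[:, None])           # shifted block weights
        l_ij = p.sum(axis=1)
        acc_o *= np.exp(m_i - m_ij)[:, None]     # rescale old accumulator
        acc_o += p @ vb
        lse_i = m_ij + np.log(np.exp(lse_i - m_ij) + l_ij)
        m_i = m_ij

    # Final normalization folds the max into the log-sum-exp:
    # exp(m_i - lse_i) is all that is needed, so only lse_i must be
    # stored for later use; a separate M tensor is redundant.
    acc_o *= np.exp(m_i - lse_i)[:, None]
    return acc_o

# sanity check against a dense softmax reference
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
s = q @ k.T
w = np.exp(s - s.max(1, keepdims=True))
ref = (w / w.sum(1, keepdims=True)) @ v
assert np.allclose(online_softmax_attention(q, k, v), ref)

Because the in-loop rescale and the final normalization compose to exp(m_i)/exp(lse_i), the two acc_o *= lines in the diff are behavior-preserving rewrites of acc_o = acc_o * ..., done in place.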