@@ -286,31 +286,32 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
                 qk = tl.where(mask, qk, float("-inf"))
         # -- compute qk ----
         if INT8_GEMM:
-            qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * QK_SCALE)
+            qk += ((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale
         else:
             if INT8_KV:
                 k = (k * k_descale).to(q.type.element_ty)
-            qk += (tl.dot(q, k) * QK_SCALE)
+            qk += tl.dot(q, k)
 
         if bias_ptrs is not None:
             bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None
             bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k)
             # While bias is added after multiplying qk with sm_scale,
             # our optimization to use 2^x instead of e^x results in an additional
             # scale factor of log2(e) which we must also multiply the bias with.
-            qk += (bias * 1.44269504089)
+            qk += (bias * 1.44269504089 / QK_SCALE)
 
         if alibi_slope is not None:
             # Compute the global position of each token within the sequence
             global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
             global_n_positions = start_n + tl.arange(0, BLOCK_N)
             alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions,
                                               global_n_positions)
-            qk += (alibi_block * 1.44269504089)  # scale factor of log2(e)
+            qk += (alibi_block * 1.44269504089 / QK_SCALE)  # scale factor of log2(e)
 
         # softmax
         m_ij = tl.maximum(m_i, tl.max(qk, 1))
-        qk = qk - m_ij[:, None]
+        m_ij_scaled = m_ij * QK_SCALE
+        qk = qk * QK_SCALE - m_ij_scaled[:, None]
         p = tl.math.exp2(qk)
 
         # CAVEAT: Must update l_ij before applying dropout
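
Note on the scaling change in this hunk: the `QK_SCALE` multiply is moved out of the `qk` accumulation and applied only where the `exp2` argument is formed, so the bias and alibi contributions are pre-divided by `QK_SCALE` to compensate. Below is a minimal NumPy sketch, not part of the kernel, assuming `QK_SCALE = sm_scale * log2(e)` and using illustrative names, that checks the old and new paths produce the same softmax numerator `p`:

```python
# Hedged sketch (not the Triton kernel): checks that pre-dividing the bias by
# QK_SCALE and applying QK_SCALE only in the exp2 argument reproduces the old
# behaviour. Assumption: QK_SCALE = sm_scale * log2(e); names are illustrative.
import numpy as np

rng = np.random.default_rng(0)
sm_scale = 0.125
QK_SCALE = sm_scale * 1.44269504089

qk = rng.standard_normal((4, 8)).astype(np.float32)    # stand-in for tl.dot(q, k)
bias = rng.standard_normal((4, 8)).astype(np.float32)

# Old path: scale qk immediately; bias only needs the log2(e) factor.
qk_old = qk * QK_SCALE + bias * 1.44269504089
m_old = qk_old.max(axis=1)
p_old = np.exp2(qk_old - m_old[:, None])

# New path: keep qk unscaled, pre-divide the bias by QK_SCALE,
# then scale once when forming the exp2 argument.
qk_new = qk + bias * 1.44269504089 / QK_SCALE
m_new = qk_new.max(axis=1)
p_new = np.exp2(qk_new * QK_SCALE - (m_new * QK_SCALE)[:, None])

assert np.allclose(p_old, p_new, atol=1e-5)
```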
@@ -324,7 +325,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
         elif RETURN_ENCODED_SOFTMAX:
             tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty))
         # -- update output accumulator --
-        alpha = tl.math.exp2(m_i - m_ij)
+        alpha = tl.math.exp2(m_i * QK_SCALE - m_ij_scaled)
         acc = acc * alpha[:, None]
         if not PRE_LOAD_V:
             v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
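
The second hunk keeps the running maximum `m_i` in unscaled `qk` units, so the accumulator rescale factor `alpha` now multiplies both maxima by `QK_SCALE` inside `exp2`. A short two-block online-softmax sketch, again with assumed names and `QK_SCALE = sm_scale * log2(e)`, suggests the early-scaled and deferred-scaled variants yield the same normalized output:

```python
# Hedged sketch: a two-block online softmax comparing "scale qk up front" with
# "defer QK_SCALE to the exp2 argument", as in the diff above. All names here
# (run, qk_blocks, v_blocks) are illustrative, not kernel symbols.
# Assumption: QK_SCALE = sm_scale * log2(e).
import numpy as np

rng = np.random.default_rng(1)
sm_scale = 0.125
QK_SCALE = sm_scale * 1.44269504089
M, N, D = 4, 8, 16

qk_blocks = [rng.standard_normal((M, N)).astype(np.float32) for _ in range(2)]  # per-block q @ k
v_blocks = [rng.standard_normal((N, D)).astype(np.float32) for _ in range(2)]

def run(scale_early):
    m_i = np.full(M, -np.inf, dtype=np.float32)   # running max (scaled or raw units)
    l_i = np.zeros(M, dtype=np.float32)           # running softmax denominator
    acc = np.zeros((M, D), dtype=np.float32)      # running output accumulator
    for qk_raw, v in zip(qk_blocks, v_blocks):
        qk = qk_raw * QK_SCALE if scale_early else qk_raw
        m_ij = np.maximum(m_i, qk.max(axis=1))
        if scale_early:
            p = np.exp2(qk - m_ij[:, None])
            alpha = np.exp2(m_i - m_ij)
        else:
            m_ij_scaled = m_ij * QK_SCALE
            p = np.exp2(qk * QK_SCALE - m_ij_scaled[:, None])
            alpha = np.exp2(m_i * QK_SCALE - m_ij_scaled)
        l_i = l_i * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v
        m_i = m_ij
    return acc / l_i[:, None]

assert np.allclose(run(scale_early=True), run(scale_early=False), atol=1e-5)
```

Since scaling by a positive constant commutes with the max, tracking `m_i` in raw units is safe; the scale is folded in only where `exp2` consumes it.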