
Commit 1468103

Compute max before mul QK_SCALE to fold sub into fma (#781)
1 parent 34ee120 commit 1468103

File tree: 1 file changed (+7 −6)


python/perf-kernels/flash-attention.py

Lines changed: 7 additions & 6 deletions
@@ -286,31 +286,32 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
             qk = tl.where(mask, qk, float("-inf"))
         # -- compute qk ----
         if INT8_GEMM:
-            qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * QK_SCALE)
+            qk += ((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale
         else:
             if INT8_KV:
                 k = (k * k_descale).to(q.type.element_ty)
-            qk += (tl.dot(q, k) * QK_SCALE)
+            qk += tl.dot(q, k)
 
         if bias_ptrs is not None:
             bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None
             bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k)
             # While bias is added after multiplying qk with sm_scale,
             # our optimization to use 2^x instead of e^x results in an additional
             # scale factor of log2(e) which we must also multiply the bias with.
-            qk += (bias * 1.44269504089)
+            qk += (bias * 1.44269504089 / QK_SCALE)
 
         if alibi_slope is not None:
             # Compute the global position of each token within the sequence
             global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
             global_n_positions = start_n + tl.arange(0, BLOCK_N)
             alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions,
                                               global_n_positions)
-            qk += (alibi_block * 1.44269504089)  # scale factor of log2(e)
+            qk += (alibi_block * 1.44269504089 / QK_SCALE)  # scale factor of log2(e)
 
         # softmax
         m_ij = tl.maximum(m_i, tl.max(qk, 1))
-        qk = qk - m_ij[:, None]
+        m_ij_scaled = m_ij * QK_SCALE
+        qk = qk * QK_SCALE - m_ij_scaled[:, None]
         p = tl.math.exp2(qk)
 
         # CAVEAT: Must update l_ij before applying dropout
@@ -324,7 +325,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
         elif RETURN_ENCODED_SOFTMAX:
             tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty))
         # -- update output accumulator --
-        alpha = tl.math.exp2(m_i - m_ij)
+        alpha = tl.math.exp2(m_i * QK_SCALE - m_ij_scaled)
         acc = acc * alpha[:, None]
         if not PRE_LOAD_V:
             v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
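
Note on the change: QK_SCALE is no longer folded into qk while the scores are accumulated. The running max is instead taken over the unscaled scores, the bias and ALiBi terms are divided by QK_SCALE so they regain their intended log2(e) scaling after the later multiply, and the scale is applied together with the max subtraction as qk * QK_SCALE - m_ij_scaled[:, None], which can be lowered to one fused multiply-add per element instead of a separate multiply and subtract. The sketch below is not part of the commit; it is a minimal NumPy check, with a made-up QK_SCALE value and block shape, that the old and new orderings produce the same exp2 result.

    import numpy as np

    # Minimal sketch (not from the commit). m_i stands in for the running row
    # max carried across K/V blocks; the old scheme keeps it scaled, the new
    # scheme keeps it unscaled and scales it only when it is used.
    QK_SCALE = 0.125 * 1.44269504089  # hypothetical sm_scale * log2(e)
    qk = np.random.randn(4, 8).astype(np.float32)   # unscaled scores for one block
    m_i = np.random.randn(4).astype(np.float32)     # running max from earlier blocks (unscaled)

    # Old ordering: scale first, then take the max and subtract it.
    qk_old = qk * QK_SCALE
    m_ij_old = np.maximum(m_i * QK_SCALE, qk_old.max(axis=1))
    p_old = np.exp2(qk_old - m_ij_old[:, None])

    # New ordering: max over unscaled scores, scale folded into the subtraction,
    # so each element is qk * QK_SCALE - m_ij_scaled (a single fma).
    m_ij = np.maximum(m_i, qk.max(axis=1))
    m_ij_scaled = m_ij * QK_SCALE
    p_new = np.exp2(qk * QK_SCALE - m_ij_scaled[:, None])

    assert np.allclose(p_old, p_new)  # same softmax numerator either way

The precomputed m_ij_scaled is also reused for the output rescaling, alpha = tl.math.exp2(m_i * QK_SCALE - m_ij_scaled), so the running max m_i now stays in the unscaled domain from one block to the next.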
