
Commit 33a25b4

[FlashAttention] Sync from upstream tensor desc implementation (part 3) (#4520)
No geomean regression (benchmark screenshot: https://github.com/user-attachments/assets/c1f2ac2e-6c47-4238-9086-52dcb337ceae)

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 40ea35f commit 33a25b4

File tree

1 file changed: 2 additions, 3 deletions


benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py

@@ -33,8 +33,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = desc_k.load([0, offsetk_y])
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
+        qk = tl.dot(q, k)
         if STAGE == 2:
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
             qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
@@ -47,7 +46,6 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         # -- compute correction factor
         alpha = tl.math.exp2(m_i - m_ij)
         l_ij = tl.sum(p, 1)
-        l_i = l_i * alpha + l_ij
         # -- update output accumulator --
         acc = acc * alpha[:, None]
         # prepare p and v for the dot
@@ -57,6 +55,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         acc = tl.dot(p, v, acc)
         # update m_i and l_i
         # place this at the end of the loop to reduce register pressure
+        l_i = l_i * alpha + l_ij
         m_i = m_ij
         offsetk_y += BLOCK_N
         offsetv_y += BLOCK_N
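
For context, below is a minimal sketch of the inner loop as it reads after this commit. Only the lines visible in the diff above are taken from the file; the function signature, the start_n loop bounds, the m_ij/p softmax bookkeeping, and the desc_v load are filled in following the structure of the upstream Triton flash attention tutorial, and should be treated as assumptions rather than the exact benchmark code.

import triton
import triton.language as tl


@triton.jit
def _attn_fwd_inner_sketch(acc, l_i, m_i, q,  #
                           desc_k, desc_v, offsetk_y, offsetv_y,  #
                           qk_scale, lo, hi, offs_m, offs_n,  #
                           STAGE: tl.constexpr, BLOCK_N: tl.constexpr):
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = desc_k.load([0, offsetk_y])
        # tl.dot accumulates in float32 by default, so the former
        # tl.zeros + `qk += tl.dot(q, k)` pair collapses into one dot.
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))  # assumed, per the tutorial
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)  # assumed
            qk = qk * qk_scale - m_ij[:, None]
        p = tl.math.exp2(qk)  # assumed, per the tutorial
        # -- compute correction factor
        alpha = tl.math.exp2(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # prepare p and v for the dot
        v = desc_v.load([offsetv_y, 0])  # assumed descriptor layout
        p = p.to(v.dtype)
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        # place this at the end of the loop to reduce register pressure
        l_i = l_i * alpha + l_ij
        m_i = m_ij
        offsetk_y += BLOCK_N
        offsetv_y += BLOCK_N
    return acc, l_i, m_i

The first change is safe because tl.dot already accumulates in float32 by default, making the explicit zero-initialized accumulator redundant; the second simply defers the l_i update until after the output-accumulator dot, matching the in-code comment about register pressure.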
