Commit 1cdfaf0

[FlashAttention] Sync from upstream tensor desc implementation (part 2) (#4470)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 4e81944 commit 1cdfaf0

File tree

1 file changed: +4 additions, -3 deletions

benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py

Lines changed: 4 additions & 3 deletions
@@ -29,7 +29,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
     offsetk_y = offset_y + lo
     offsetv_y = offset_y + lo
     # loop over k, v and update accumulator
-    for start_n in range(lo, hi, BLOCK_N):
+    for start_n in tl.range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = desc_k.load([0, offsetk_y])
@@ -44,16 +44,17 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
         qk = qk * qk_scale - m_ij[:, None]
         p = tl.math.exp2(qk)
-        l_ij = tl.sum(p, 1)
         # -- compute correction factor
         alpha = tl.math.exp2(m_i - m_ij)
+        l_ij = tl.sum(p, 1)
         l_i = l_i * alpha + l_ij
         # -- update output accumulator --
         acc = acc * alpha[:, None]
         # prepare p and v for the dot
         v = desc_v.load([offsetv_y, 0])
+        p = p.to(dtype)
         # note that this non transposed v for FP8 is only supported on Blackwell
-        acc += tl.dot(p.to(tl.float16), v)
+        acc = tl.dot(p, v, acc)
         # update m_i and l_i
         # place this at the end of the loop to reduce register pressure
         m_i = m_ij
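
The functional payload of this sync is small but worth spelling out: the loop now uses tl.range instead of Python's range, giving Triton an explicit, pipelinable loop construct; l_ij is computed after the correction factor alpha, matching the upstream ordering; p is cast to the kernel's dtype parameter rather than a hard-coded tl.float16; and the accumulation is folded into the dot itself via acc = tl.dot(p, v, acc). A minimal standalone sketch of the tl.range / fused-accumulator pattern follows; the kernel name, shapes, grid, and BLOCK_K are illustrative assumptions, not taken from the benchmark file.

import torch
import triton
import triton.language as tl


@triton.jit
def _dot_acc_kernel(a_ptr, b_ptr, c_ptr,  #
                    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,  #
                    BLOCK_K: tl.constexpr):
    # Illustrative single-program GEMM: accumulate over K in BLOCK_K chunks.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    acc = tl.zeros((M, N), dtype=tl.float32)
    # tl.range (rather than plain range) is Triton's loop construct, which the
    # compiler can pipeline explicitly.
    for k0 in tl.range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        # Fused accumulation: passing acc as the third argument is the
        # accumulator form of acc += tl.dot(a, b).
        acc = tl.dot(a, b, acc)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


def dot_acc(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Assumes M, N, K are powers of two, at least 16, and K is a multiple of
    # BLOCK_K, so tl.arange/tl.dot are legal and no masking is needed.
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    _dot_acc_kernel[(1, )](a, b, c, M, N, K, BLOCK_K=32)
    return c

For example, dot_acc(torch.randn(64, 128, device=dev, dtype=torch.float16), torch.randn(128, 64, device=dev, dtype=torch.float16)) should match a @ b up to accumulation error on whichever device dev Triton targets ('xpu' for this repository, 'cuda' elsewhere); acc = tl.dot(p, v, acc) in the benchmark's inner loop follows the same pattern.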
