
Commit 09461ac

[Tutorial] Fix 06-fused-attention.py of FP8 dtype (#7043)
When the provider is `fp8`, `v` is permuted as shown below, and its new stride is `(H*N_CTX*HEAD_DIM, N_CTX*HEAD_DIM, 1, N_CTX)`:

```python
if mode == "fwd" and "fp8" in provider:
    v = v.permute(0, 1, 3, 2).contiguous()
    v = v.permute(0, 1, 3, 2)
```

This PR fixes the FP8 dtype handling in the fused-attention kernel by separating the `k` and `v` offset calculations and updating the related configuration details. Key changes include:

- Renaming and separating the offset variables for the `k` and `v` computations.
- Adjusting the offset calculation for the FP8 dtype and updating the tensor descriptor creation.
- Expanding the configuration options for `BLOCK_N` and refining the device-specific configuration conditions.

Signed-off-by: Whitney Tsang <[email protected]>
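For reference, a minimal PyTorch sketch (not part of the commit; the shapes and dtype here are only illustrative) that reproduces the stride pattern described above:

```python
import torch

# Illustrative shapes for the benchmark layout (Z, H, N_CTX, HEAD_DIM); the real
# benchmark uses an FP8 CUDA tensor, but the stride arithmetic is the same.
Z, H, N_CTX, HEAD_DIM = 1, 2, 1024, 64
v = torch.empty(Z, H, N_CTX, HEAD_DIM, dtype=torch.float16)

# fp8 forward path: materialize v in head-dim-major memory order, then view it
# back with the original logical shape.
v = v.permute(0, 1, 3, 2).contiguous()
v = v.permute(0, 1, 3, 2)

print(v.shape)     # torch.Size([1, 2, 1024, 64])
print(v.stride())  # (131072, 65536, 1, 1024) == (H*N_CTX*HEAD_DIM, N_CTX*HEAD_DIM, 1, N_CTX)
```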
1 parent 5efc32b commit 09461ac

File tree

1 file changed (+19, -7 lines)

python/tutorials/06-fused-attention.py

Lines changed: 19 additions & 7 deletions
@@ -56,12 +56,16 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
     # causal = False
     else:
         lo, hi = 0, N_CTX
-    offsetkv_y = offset_y + lo
+    offsetk_y = offset_y + lo
+    if dtype == tl.float8e5:
+        offsetv_y = offset_y * HEAD_DIM + lo
+    else:
+        offsetv_y = offset_y + lo
     # loop over k, v and update accumulator
     for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = desc_k.load([offsetkv_y, 0]).T
+        k = desc_k.load([offsetk_y, 0]).T
         qk = tl.dot(q, k)
         if STAGE == 2:
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -78,15 +82,19 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         # -- update output accumulator --
         acc = acc * alpha[:, None]
         # prepare p and v for the dot
-        v = desc_v.load([offsetkv_y, 0])
+        if dtype == tl.float8e5:
+            v = desc_v.load([0, offsetv_y]).T
+        else:
+            v = desc_v.load([offsetv_y, 0])
         p = p.to(dtype)
         # note that this non transposed v for FP8 is only supported on Blackwell
         acc = tl.dot(p, v, acc)
         # update m_i and l_i
         # place this at the end of the loop to reduce register pressure
         l_i = l_i * alpha + l_ij
         m_i = m_ij
-        offsetkv_y += BLOCK_N
+        offsetk_y += BLOCK_N
+        offsetv_y += BLOCK_N
     return acc, l_i, m_i

@@ -112,7 +120,7 @@ def _host_descriptor_pre_hook(nargs):
 configs = [
     triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \
     for BM in [64, 128]\
-    for BN in [64, 128]\
+    for BN in [32, 64, 128]\
     for s in NUM_STAGES_OPTIONS \
     for w in [4, 8]\
 ]
@@ -167,8 +175,12 @@ def _attn_fwd(sm_scale, M, #
     y_dim = Z * H * N_CTX
     desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_M, HEAD_DIM])
-    desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
-                                     block_shape=[BLOCK_N, HEAD_DIM])
+    if FP8_OUTPUT:
+        desc_v = _maybe_make_tensor_desc(desc_v, shape=[HEAD_DIM, y_dim], strides=[N_CTX, 1],
+                                         block_shape=[HEAD_DIM, BLOCK_N])
+    else:
+        desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
+                                         block_shape=[BLOCK_N, HEAD_DIM])
     desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_N, HEAD_DIM])
     desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
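As a sanity check on the FP8 branch above, here is a small NumPy sketch (illustration only, not part of the commit) of the descriptor's pointer arithmetic. It assumes `offset_y` is the per-(batch, head) row offset into the flattened y dimension, i.e. `(z * H + h) * N_CTX`; under that assumption, the element the FP8 descriptor reads at `[d, offsetv_y + n]` with strides `[N_CTX, 1]` is exactly `v[z, h, lo + n, d]` of the permuted tensor, so loading a `[HEAD_DIM, BLOCK_N]` block and transposing it recovers the `[BLOCK_N, HEAD_DIM]` tile of V.

```python
import numpy as np

# Illustrative sizes only.
Z, H, N_CTX, HEAD_DIM = 1, 2, 128, 16

# Buffer laid out as the fp8 benchmark stores v: contiguous in (Z, H, HEAD_DIM, N_CTX) order.
buf = np.arange(Z * H * HEAD_DIM * N_CTX, dtype=np.int64)
v_physical = buf.reshape(Z, H, HEAD_DIM, N_CTX)   # head-dim-major storage
v = v_physical.transpose(0, 1, 3, 2)              # logical (Z, H, N_CTX, HEAD_DIM) view

z, h, lo, n, d = 0, 1, 32, 5, 7
offset_y = (z * H + h) * N_CTX        # assumed per-(batch, head) row offset (see lead-in)
offsetv_y = offset_y * HEAD_DIM + lo  # the FP8 branch of the fix

# Linear address the FP8 descriptor computes for element [d, offsetv_y + n],
# given shape [HEAD_DIM, y_dim] and strides [N_CTX, 1].
addr = d * N_CTX + (offsetv_y + n) * 1

assert buf[addr] == v[z, h, lo + n, d]
print("FP8 descriptor address matches v[z, h, lo + n, d]")
```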
