Commit db09298

[06-fused-attention] Temporarily modify the source code to recover performance
This commit should be reverted once the new transpose implementation can be handled efficiently.
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 09461ac commit db09298

1 file changed

python/tutorials/06-fused-attention.py

Lines changed: 9 additions & 6 deletions
@@ -40,6 +40,9 @@ def is_blackwell():
     return is_cuda() and torch.cuda.get_device_capability()[0] == 10
 
 
+# FIXME: Revert temporary source code modification done in last commit of PR #4399.
+
+
 @triton.jit
 def _attn_fwd_inner(acc, l_i, m_i, q, #
                     desc_k, desc_v, #
@@ -65,7 +68,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
     for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = desc_k.load([offsetk_y, 0]).T
+        k = desc_k.load([0, offsetk_y])
         qk = tl.dot(q, k)
         if STAGE == 2:
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -83,7 +86,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         acc = acc * alpha[:, None]
         # prepare p and v for the dot
         if dtype == tl.float8e5:
-            v = desc_v.load([0, offsetv_y]).T
+            v = desc_v.load([offsetv_y, 0])
         else:
             v = desc_v.load([offsetv_y, 0])
         p = p.to(dtype)
@@ -176,13 +179,13 @@ def _attn_fwd(sm_scale, M, #
     desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_M, HEAD_DIM])
     if FP8_OUTPUT:
-        desc_v = _maybe_make_tensor_desc(desc_v, shape=[HEAD_DIM, y_dim], strides=[N_CTX, 1],
-                                         block_shape=[HEAD_DIM, BLOCK_N])
+        desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[1, N_CTX],
+                                         block_shape=[BLOCK_N, HEAD_DIM])
     else:
         desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                          block_shape=[BLOCK_N, HEAD_DIM])
-    desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
-                                     block_shape=[BLOCK_N, HEAD_DIM])
+    desc_k = _maybe_make_tensor_desc(desc_k, shape=[HEAD_DIM, y_dim], strides=[1, HEAD_DIM],
+                                     block_shape=[HEAD_DIM, BLOCK_N])
     desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_M, HEAD_DIM])
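The edits above fold the transpose of K (and of V in the FP8 path) into the tensor descriptor itself: the descriptor is built with swapped shape and strides, so each block already comes out of the load in the layout tl.dot expects, and the .T after the load disappears. Below is a minimal sketch of why the two load patterns are equivalent, using PyTorch as_strided on the host as a stand-in for the device-side descriptor; the buffer name, sizes, and offset are illustrative assumptions, not code from the tutorial.

# A minimal sketch, not part of the tutorial: k_buf and the sizes below are
# illustrative assumptions. It mimics on the host what the swapped
# shape/strides descriptor does on the device.
import torch

N_CTX, HEAD_DIM, BLOCK_N = 128, 64, 32   # illustrative sizes
offsetk_y = 64                           # illustrative tile offset

# Row-major K buffer: shape [N_CTX, HEAD_DIM], strides [HEAD_DIM, 1].
k_buf = torch.arange(N_CTX * HEAD_DIM, dtype=torch.float32).reshape(N_CTX, HEAD_DIM)

# Old pattern: load a [BLOCK_N, HEAD_DIM] tile at [offsetk_y, 0], then transpose it.
tile_then_transpose = k_buf[offsetk_y:offsetk_y + BLOCK_N, :].T

# New pattern: view the same buffer as K^T (shape [HEAD_DIM, N_CTX], strides
# [1, HEAD_DIM]) and read a [HEAD_DIM, BLOCK_N] tile at [0, offsetk_y];
# no transpose is needed after the load.
k_t_view = torch.as_strided(k_buf, size=(HEAD_DIM, N_CTX), stride=(1, HEAD_DIM))
tile_pre_transposed = k_t_view[:, offsetk_y:offsetk_y + BLOCK_N]

assert torch.equal(tile_then_transpose, tile_pre_transposed)

The same reasoning applies to the desc_v descriptor in the FP8_OUTPUT branch: once the transpose is expressed through the descriptor's strides, the kernel body performs only plain descriptor loads.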
