
Commit 2a9e85e

[Tutorial] Fix subtile flags for blackwell (#7679)
1 parent 7b618d9 commit 2a9e85e

File tree: 1 file changed (+9 -9 lines)

1 file changed

+9
-9
lines changed

python/tutorials/06-fused-attention.py

Lines changed: 9 additions & 9 deletions
@@ -50,7 +50,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
                     offset_y, dtype: tl.constexpr, start_m, qk_scale, #
                     BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
                     STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
-                    N_CTX: tl.constexpr, warp_specialize: tl.constexpr, SUBTILE_EPILOGUE: tl.constexpr):
+                    N_CTX: tl.constexpr, warp_specialize: tl.constexpr, IS_HOPPER: tl.constexpr):
     # range of values handled by this stage
     if STAGE == 1:
         lo, hi = 0, start_m * BLOCK_M
@@ -84,7 +84,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         alpha = tl.math.exp2(m_i - m_ij)
         l_ij = tl.sum(p, 1)
         # -- update output accumulator --
-        if SUBTILE_EPILOGUE:
+        if not IS_HOPPER and warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
             BM: tl.constexpr = acc.shape[0]
             BN: tl.constexpr = acc.shape[1]
             acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
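For context on the branch this commit re-gates: the subtile epilogue splits the accumulator into two column halves so each half can be rescaled and handled as a separate subtile. Below is a minimal NumPy sketch (hypothetical small sizes, NumPy standing in for the tl reshape/permute/split ops) of what that sequence produces.

```python
import numpy as np

# Hypothetical small tile; in the kernel the accumulator is BLOCK_M x HEAD_DIM.
BM, BN = 4, 8
acc = np.arange(BM * BN, dtype=np.float32).reshape(BM, BN)

# Mirror acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() from the kernel.
reshaped = acc.reshape(BM, 2, BN // 2)           # group the columns into two halves
permuted = reshaped.transpose(0, 2, 1)           # -> shape [BM, BN // 2, 2]
acc0, acc1 = permuted[..., 0], permuted[..., 1]  # split the trailing size-2 dim

assert np.array_equal(acc0, acc[:, :BN // 2])    # left column half
assert np.array_equal(acc1, acc[:, BN // 2:])    # right column half
```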
@@ -179,7 +179,7 @@ def _attn_fwd(sm_scale, M, #
               FP8_OUTPUT: tl.constexpr, #
               STAGE: tl.constexpr, #
               warp_specialize: tl.constexpr, #
-              SUBTILE_EPILOGUE: tl.constexpr, #
+              IS_HOPPER: tl.constexpr, #
               ):
     dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
     tl.static_assert(BLOCK_N <= HEAD_DIM)
@@ -225,15 +225,15 @@ def _attn_fwd(sm_scale, M, #
                                         offset_y, dtype, start_m, qk_scale, #
                                         BLOCK_M, HEAD_DIM, BLOCK_N, #
                                         4 - STAGE, offs_m, offs_n, N_CTX, #
-                                        warp_specialize, SUBTILE_EPILOGUE)
+                                        warp_specialize, IS_HOPPER)
     # stage 2: on-band
     if STAGE & 2:
         acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, #
                                         desc_k, desc_v, #
                                         offset_y, dtype, start_m, qk_scale, #
                                         BLOCK_M, HEAD_DIM, BLOCK_N, #
                                         2, offs_m, offs_n, N_CTX, #
-                                        warp_specialize, SUBTILE_EPILOGUE)
+                                        warp_specialize, IS_HOPPER)
     # epilogue
     m_i += tl.math.log2(l_i)
     acc = acc / l_i[:, None]
@@ -539,7 +539,6 @@ def grid(META):
             return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)

         ctx.grid = grid
-        SUBTILE_EPILOGUE = False if is_hopper() and warp_specialize else True
         if is_blackwell() and warp_specialize:
             if HEAD_DIM_K == 128 and q.dtype == torch.float16:
                 extra_kern_args["maxnreg"] = 168
@@ -554,7 +553,8 @@ def grid(META):
             FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
             STAGE=stage, #
             warp_specialize=warp_specialize, #
-            SUBTILE_EPILOGUE=SUBTILE_EPILOGUE, **extra_kern_args)
+            IS_HOPPER=is_hopper(), #
+            **extra_kern_args)

         ctx.save_for_backward(q, k, v, o, M)
         ctx.sm_scale = sm_scale
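The net effect of the host-side change: instead of computing a SUBTILE_EPILOGUE flag in forward() and passing it as a constexpr, the launch now passes IS_HOPPER and the kernel derives the subtile decision from the tile shape. A small sketch (hypothetical helper names, plain Python) comparing the old and new decisions:

```python
def subtile_old(is_hopper: bool, warp_specialize: bool) -> bool:
    # Pre-commit host-side flag: subtile everywhere except Hopper + warp specialization.
    return not (is_hopper and warp_specialize)

def subtile_new(is_hopper: bool, warp_specialize: bool, block_m: int, head_dim: int) -> bool:
    # Post-commit in-kernel condition built from IS_HOPPER and the tile shape.
    return (not is_hopper) and warp_specialize and block_m == 128 and head_dim == 128

# Example: a non-warp-specialized launch no longer takes the subtile path.
print(subtile_old(False, False))            # True
print(subtile_new(False, False, 128, 128))  # False
```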
@@ -692,8 +692,8 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtyp
 for mode in ["fwd", "bwd"]:
     for causal in [True, False]:
         # Enable warpspec for causal fwd on Hopper
-        for warp_specialize in [False, True] if (is_blackwell() or
-                                                 (is_hopper() and mode == "fwd" and not causal)) else [False]:
+        enable_ws = mode == "fwd" and (is_blackwell() or (is_hopper() and not causal))
+        for warp_specialize in [False, True] if enable_ws else [False]:
             configs.append(
                 triton.testing.Benchmark(
                     x_names=["N_CTX"],
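The benchmark-config change reads the same way: warp specialization is now swept only for forward passes, and on Hopper only when the benchmark is non-causal. A small sketch (hypothetical helper, plain Python) of which values get enumerated under assumed hardware flags:

```python
def warp_specialize_values(mode: str, causal: bool, blackwell: bool, hopper: bool):
    # Mirrors the updated loop guard from the diff above.
    enable_ws = mode == "fwd" and (blackwell or (hopper and not causal))
    return [False, True] if enable_ws else [False]

# Examples: bwd no longer sweeps warp specialization; causal fwd still does on
# Blackwell but not on Hopper.
print(warp_specialize_values("bwd", False, blackwell=True, hopper=False))  # [False]
print(warp_specialize_values("fwd", True, blackwell=True, hopper=False))   # [False, True]
print(warp_specialize_values("fwd", True, blackwell=False, hopper=True))   # [False]
```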

0 commit comments