Commit 7b618d9

[WS][Hopper] Enable Warpspec for FA with non-causal and on-device TMA (#7658)
Hopper warp specialization doesn't support data partitioning with on-host TMA, and subtiling also causes issues with our data partitioning. We can try explicit data partitioning in the kernel.
Parent: 07da911
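
In short, the change adds an is_hopper() helper and uses it to route Hopper warp-specialized runs away from host-built TMA descriptors. A minimal sketch of that gate, using torch.cuda.get_device_capability() in place of the tutorial's is_cuda() helper and assuming the tutorial's existing supports_host_descriptor():

import torch

def is_hopper() -> bool:
    # Hopper GPUs report compute capability 9.x
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9

def use_host_tma(warp_specialize: bool) -> bool:
    # Hopper warp specialization cannot partition data loaded through
    # host-built TMA descriptors, so fall back to on-device descriptors there.
    return supports_host_descriptor() and not (is_hopper() and warp_specialize)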

1 file changed: python/tutorials/06-fused-attention.py (17 additions, 8 deletions)
@@ -40,13 +40,17 @@ def is_blackwell():
     return is_cuda() and torch.cuda.get_device_capability()[0] == 10
 
 
+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 9
+
+
 @triton.jit
 def _attn_fwd_inner(acc, l_i, m_i, q,  #
                     desc_k, desc_v,  #
                     offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
                     BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                     STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
-                    N_CTX: tl.constexpr, warp_specialize: tl.constexpr):
+                    N_CTX: tl.constexpr, warp_specialize: tl.constexpr, SUBTILE_EPILOGUE: tl.constexpr):
     # range of values handled by this stage
     if STAGE == 1:
         lo, hi = 0, start_m * BLOCK_M
@@ -80,7 +84,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
         alpha = tl.math.exp2(m_i - m_ij)
         l_ij = tl.sum(p, 1)
         # -- update output accumulator --
-        if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
+        if SUBTILE_EPILOGUE:
             BM: tl.constexpr = acc.shape[0]
             BN: tl.constexpr = acc.shape[1]
             acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
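
For reference, the reshape/permute/split sequence above splits the [BM, BN] accumulator into its left and right column halves so each half can be rescaled separately. A quick PyTorch illustration of the same index math (standalone demo, not part of the tutorial):

import torch

BM, BN = 128, 128
acc = torch.arange(BM * BN, dtype=torch.float32).reshape(BM, BN)

# Same layout trick as acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
halves = acc.reshape(BM, 2, BN // 2).permute(0, 2, 1)  # [BM, BN // 2, 2]
acc0, acc1 = halves.unbind(dim=-1)                     # analogue of tl.split()

assert torch.equal(acc0, acc[:, :BN // 2])   # left half of the columns
assert torch.equal(acc1, acc[:, BN // 2:])   # right half of the columns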
@@ -175,6 +179,7 @@ def _attn_fwd(sm_scale, M,  #
               FP8_OUTPUT: tl.constexpr,  #
               STAGE: tl.constexpr,  #
               warp_specialize: tl.constexpr,  #
+              SUBTILE_EPILOGUE: tl.constexpr,  #
               ):
     dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
     tl.static_assert(BLOCK_N <= HEAD_DIM)
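
Because SUBTILE_EPILOGUE is a tl.constexpr, it becomes part of the kernel specialization and the if SUBTILE_EPILOGUE: branch in _attn_fwd_inner is resolved at compile time. A tiny standalone sketch of that behavior (hypothetical _demo kernel, not from the tutorial):

import torch
import triton
import triton.language as tl

@triton.jit
def _demo(out_ptr, SUBTILE_EPILOGUE: tl.constexpr):
    # Resolved during compilation: each constexpr value gets its own compiled kernel.
    if SUBTILE_EPILOGUE:
        val = 2.0
    else:
        val = 1.0
    tl.store(out_ptr, val)

out = torch.empty(1, device="cuda", dtype=torch.float32)
_demo[(1, )](out, SUBTILE_EPILOGUE=True)   # out == tensor([2.])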
@@ -220,15 +225,15 @@ def _attn_fwd(sm_scale, M,  #
                                         offset_y, dtype, start_m, qk_scale,  #
                                         BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                         4 - STAGE, offs_m, offs_n, N_CTX,  #
-                                        warp_specialize)
+                                        warp_specialize, SUBTILE_EPILOGUE)
     # stage 2: on-band
     if STAGE & 2:
         acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q,  #
                                         desc_k, desc_v,  #
                                         offset_y, dtype, start_m, qk_scale,  #
                                         BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                         2, offs_m, offs_n, N_CTX,  #
-                                        warp_specialize)
+                                        warp_specialize, SUBTILE_EPILOGUE)
     # epilogue
     m_i += tl.math.log2(l_i)
     acc = acc / l_i[:, None]
@@ -504,7 +509,8 @@ def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True):
             extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}
 
         M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
-        if supports_host_descriptor():
+        # Use device_descriptor for Hopper + warpspec.
+        if supports_host_descriptor() and not (is_hopper() and warp_specialize):
             # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
             y_dim = q.shape[0] * q.shape[1] * q.shape[2]
 
@@ -533,7 +539,8 @@ def grid(META):
             return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
 
         ctx.grid = grid
-        if is_cuda() and warp_specialize:
+        SUBTILE_EPILOGUE = False if is_hopper() and warp_specialize else True
+        if is_blackwell() and warp_specialize:
             if HEAD_DIM_K == 128 and q.dtype == torch.float16:
                 extra_kern_args["maxnreg"] = 168
             else:
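
The conditional expression above is a negation written long-hand; either form disables the subtiled epilogue exactly when Hopper runs with warp specialization (sketch, not a proposed change to the file):

SUBTILE_EPILOGUE = not (is_hopper() and warp_specialize)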
@@ -547,7 +554,7 @@ def grid(META):
             FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
             STAGE=stage,  #
             warp_specialize=warp_specialize,  #
-            **extra_kern_args)
+            SUBTILE_EPILOGUE=SUBTILE_EPILOGUE, **extra_kern_args)
 
         ctx.save_for_backward(q, k, v, o, M)
         ctx.sm_scale = sm_scale
@@ -684,7 +691,9 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtyp
 for HEAD_DIM in [64, 128]:
     for mode in ["fwd", "bwd"]:
         for causal in [True, False]:
-            for warp_specialize in [False, True] if is_blackwell() else [False]:
+            # Enable warpspec for non-causal fwd on Hopper
+            for warp_specialize in [False, True] if (is_blackwell() or
+                                                     (is_hopper() and mode == "fwd" and not causal)) else [False]:
                 configs.append(
                     triton.testing.Benchmark(
                         x_names=["N_CTX"],
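
With the widened condition, the benchmark sweep now also covers warp specialization for the non-causal forward pass on Hopper. A small sketch of the options it selects per configuration, assuming the tutorial's is_blackwell()/is_hopper() helpers:

def warpspec_options(mode: str, causal: bool) -> list[bool]:
    # Mirrors the loop condition above.
    if is_blackwell() or (is_hopper() and mode == "fwd" and not causal):
        return [False, True]
    return [False]

# On a Hopper GPU:
#   warpspec_options("fwd", causal=False) -> [False, True]
#   warpspec_options("fwd", causal=True)  -> [False]
#   warpspec_options("bwd", causal=False) -> [False]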
