
Commit d748303

[Tutorial] Improve dhead=128 ws performance for attention (#7195)
Pushes HEAD_DIM=128 performance for warp-specialized (ws) attention up to 850 TFLOPS. Software pipelining (SWP) still beats it at almost 970 TFLOPS, though. This helps as a baseline against Gluon.
1 parent 7a342f2 commit d748303

1 file changed

python/tutorials/06-fused-attention.py

Lines changed: 41 additions & 31 deletions
@@ -76,7 +76,15 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         alpha = tl.math.exp2(m_i - m_ij)
         l_ij = tl.sum(p, 1)
         # -- update output accumulator --
-        acc = acc * alpha[:, None]
+        if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
+            BM: tl.constexpr = acc.shape[0]
+            BN: tl.constexpr = acc.shape[1]
+            acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
+            acc0 = acc0 * alpha[:, None]
+            acc1 = acc1 * alpha[:, None]
+            acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
+        else:
+            acc = acc * alpha[:, None]
         # prepare p and v for the dot
         v = desc_v.load([offsetkv_y, 0])
         p = p.to(dtype)
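Note: the split/scale/join sequence added above is numerically identical to the single acc = acc * alpha[:, None] it replaces; it only carves the accumulator into its left and right column halves and rescales each half separately. A minimal NumPy sketch of that equivalence (illustrative only, not part of the commit; the names mirror the hunk):

import numpy as np

BM, BN = 128, 128                                  # mirrors BLOCK_M / HEAD_DIM in the hunk
acc = np.random.rand(BM, BN).astype(np.float32)    # output accumulator
alpha = np.random.rand(BM).astype(np.float32)      # per-row rescale factor

# Reference path: scale the whole accumulator at once.
ref = acc * alpha[:, None]

# Split path: view acc as [BM, 2, BN // 2], move the size-2 axis last,
# peel off the two column halves, scale each, then reassemble.
halves = acc.reshape(BM, 2, BN // 2).transpose(0, 2, 1)
acc0, acc1 = halves[..., 0], halves[..., 1]
acc0 = acc0 * alpha[:, None]
acc1 = acc1 * alpha[:, None]
joined = np.stack([acc0, acc1], axis=-1).transpose(0, 2, 1).reshape(BM, BN)

assert np.allclose(ref, joined)                    # same values, computed in two halves

The likely benefit is a smaller per-instruction register footprint for the warp-specialized partitioning; the commit only reports the resulting TFLOPS, so treat that motivation as an assumption.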
@@ -119,7 +127,7 @@ def _host_descriptor_pre_hook(nargs):
 if "PYTEST_VERSION" in os.environ:
     # Use a single config in testing for reproducibility
     configs = [
-        triton.Config(dict(BLOCK_M=64, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook),
+        triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook),
     ]
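Note: a triton.Config like the one above only takes effect through the @triton.autotune decorator that wraps the tutorial's kernel. A hedged, self-contained sketch of that wiring (the toy kernel, its key, and the launch are placeholders rather than the tutorial's real signature, and the pre_hook is omitted):

import torch
import triton
import triton.language as tl

configs = [
    triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4),
]

@triton.autotune(configs=configs, key=["n"])  # "n" is a placeholder autotune key
@triton.jit
def _toy_zero(x_ptr, n, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # BLOCK_M / BLOCK_N come from the selected Config; with a single config,
    # as in the PYTEST_VERSION branch above, the selection is deterministic.
    # (BLOCK_N is unused here and kept only to mirror the Config keys.)
    offs = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(x_ptr + offs, tl.zeros([BLOCK_M], dtype=tl.float32), mask=offs < n)

x = torch.empty(1024, device="cuda")
grid = lambda META: (triton.cdiv(x.numel(), META["BLOCK_M"]),)  # grid follows the tuned BLOCK_M
_toy_zero[grid](x, x.numel())

Pinning a single config in the test path keeps CI runs reproducible, since autotuning otherwise picks whichever config benchmarks fastest on the machine at hand.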

@@ -505,7 +513,10 @@ def grid(META):

         ctx.grid = grid
         if is_cuda() and warp_specialize:
-            extra_kern_args["maxnreg"] = 80
+            if HEAD_DIM_K == 128 and q.dtype == torch.float16:
+                extra_kern_args["maxnreg"] = 168
+            else:
+                extra_kern_args["maxnreg"] = 80
         _attn_fwd[grid](
             sm_scale, M, #
             q.shape[0], q.shape[1], #
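Note: extra_kern_args is forwarded to the _attn_fwd[grid](...) launch, so the branch above simply raises the per-thread register cap (maxnreg) from 80 to 168 for the fp16, HEAD_DIM=128 warp-specialized case. A hedged sketch of the same launch-option pattern on a trivial kernel (the kernel and the launch helper are placeholders; only the maxnreg dispatch mirrors the diff):

import torch
import triton
import triton.language as tl

@triton.jit
def _toy_copy(dst_ptr, src_ptr, n, BLOCK: tl.constexpr):
    # Placeholder kernel; the point is the launch-time options passed below.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

def launch(dst, src, head_dim, warp_specialize):
    extra_kern_args = {}
    if warp_specialize:
        # Same dispatch as the diff: a larger register budget only for the
        # HEAD_DIM=128 / fp16 path, the previous cap of 80 otherwise.
        wide = head_dim == 128 and src.dtype == torch.float16
        extra_kern_args["maxnreg"] = 168 if wide else 80
    n = dst.numel()
    _toy_copy[(triton.cdiv(n, 256),)](dst, src, n, BLOCK=256, **extra_kern_args)

A lower maxnreg helps pack more warps per SM, while the higher cap gives the HEAD_DIM=128 data path more room; the specific value of 168 comes straight from the diff, not from any tuning done here.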
@@ -620,36 +631,35 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16)
     HAS_FLASH = False

 TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
-BATCH, N_HEADS, HEAD_DIM = 4, 32, 64
+BATCH, N_HEADS = 4, 32
 # vary seq length for fixed head and batch=4
 configs = []
-for mode in ["fwd", "bwd"]:
-    for causal in [True, False]:
-        for warp_specialize in [False, True] if is_blackwell() else [False]:
-            if mode == "bwd" and not causal:
-                continue
-            configs.append(
-                triton.testing.Benchmark(
-                    x_names=["N_CTX"],
-                    x_vals=[2**i for i in range(10, 15)],
-                    line_arg="provider",
-                    line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) +
-                    (["flash"] if HAS_FLASH else []),
-                    line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) +
-                    (["Flash-2"] if HAS_FLASH else []),
-                    styles=[("red", "-"), ("blue", "-"), ("green", "-")],
-                    ylabel="TFLOPS",
-                    plot_name=
-                    f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}",
-                    args={
-                        "H": N_HEADS,
-                        "BATCH": BATCH,
-                        "HEAD_DIM": HEAD_DIM,
-                        "mode": mode,
-                        "causal": causal,
-                        "warp_specialize": warp_specialize,
-                    },
-                ))
+for HEAD_DIM in [64, 128]:
+    for mode in ["fwd", "bwd"]:
+        for causal in [True, False]:
+            for warp_specialize in [False, True] if is_blackwell() else [False]:
+                configs.append(
+                    triton.testing.Benchmark(
+                        x_names=["N_CTX"],
+                        x_vals=[2**i for i in range(10, 15)],
+                        line_arg="provider",
+                        line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) +
+                        (["flash"] if HAS_FLASH else []),
+                        line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) +
+                        (["Flash-2"] if HAS_FLASH else []),
+                        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
+                        ylabel="TFLOPS",
+                        plot_name=
+                        f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}",
+                        args={
+                            "H": N_HEADS,
+                            "BATCH": BATCH,
+                            "HEAD_DIM": HEAD_DIM,
+                            "mode": mode,
+                            "causal": causal,
+                            "warp_specialize": warp_specialize,
+                        },
+                    ))


 @triton.testing.perf_report(configs)
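Note: each triton.testing.Benchmark in configs becomes one table/plot when the decorated benchmark function is run. A hedged sketch of how such a perf_report benchmark is typically driven (the body below is a placeholder workload, not this file's real implementation, and it ignores provider, mode, causal and warp_specialize, which a real body would dispatch on):

import torch
import triton

@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, mode, causal, warp_specialize, provider, device="cuda"):
    # Placeholder workload: a single QK^T-shaped matmul stands in for the fused kernel,
    # only to show how a timed callable becomes the TFLOPS the report plots.
    q = torch.randn(N_CTX, HEAD_DIM, device=device, dtype=torch.float16)
    k = torch.randn(N_CTX, HEAD_DIM, device=device, dtype=torch.float16)
    ms = triton.testing.do_bench(lambda: q @ k.T)
    flops = 2.0 * N_CTX * N_CTX * HEAD_DIM    # FLOPs of the stand-in matmul only
    return flops * 1e-12 / (ms * 1e-3)        # convert ms to TFLOPS

if __name__ == "__main__":
    bench_flash_attention.run(save_path=".", print_data=True)  # one report per Benchmark entry

The HEAD_DIM loop added in this commit just doubles the sweep, producing separate d=64 and d=128 reports from the same benchmark body.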
