
Commit 17966f4
[TUTORIAL] Cleanup persistent matmul tutorial (triton-lang#6374)
This changes the tutorial to try both warp-specialized and non-warp-specialized configs when available. I also updated the verification and benchmark steps to print incremental progress updates so you can tell the script hasn't hung. Example output from the middle of a run:

```
M=32, N=32, K=32, verification naive vs: 
  Torch: ⭕
  cuBLAS: ✅
  Persistent: ✅
  TMA (warp_specialize=False): ✅
  TMA (warp_specialize=True): ✅
  TMA Persistent (warp_specialize=False): ✅
  TMA Persistent (warp_specialize=True): ✅
  Tensor Descriptor Persistent (warp_specialize=False): ✅
  Tensor Descriptor Persistent (warp_specialize=True): ...
```

We should probably rename this to advanced-matmul at this point, tbh.
1 parent c23e300 commit 17966f4

File tree

1 file changed: +92 −76 lines

python/tutorials/09-persistent-matmul.py

Lines changed: 92 additions & 76 deletions
```diff
@@ -20,6 +20,7 @@
 """
 
 import argparse
+import itertools
 
 import torch
 import triton
```
```diff
@@ -46,10 +47,15 @@ def supports_tma():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
 
 
+def supports_ws():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 10
+
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
-    M, N, K = args["M"], args["N"], args["K"]
-    ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
+    M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
+    ws_str = "_ws" if WS else ""
+    ret["name"] = f"{kernel.name}{ws_str} [M={M}, N={N}, K={K}]"
     if "c_ptr" in args:
         bytes_per_elem = args["c_ptr"].element_size()
     else:
```
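As a side note on the new gate: `supports_ws()` keys off the device's compute capability, using the same probe as `supports_tma()`. Below is a standalone sketch of that probe, my own illustration rather than part of the diff; the sm_90/sm_100 thresholds correspond to Hopper and Blackwell (the diff's own FIXME comment ties warp specialization to Blackwell).

```python
# Illustration only: probe the gates used by supports_tma()/supports_ws() above.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"sm_{major}{minor}: TMA supported={major >= 9}, "
          f"warp specialization supported={major >= 10}")
else:
    print("CUDA not available")
```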
```diff
@@ -61,6 +67,7 @@ def _matmul_launch_metadata(grid, kernel, args):
 
 HAS_TMA_DESC = supports_tma() and hasattr(tl, "nv_tma_desc_type")
 HAS_TENSOR_DESC = supports_tma() and hasattr(tl, "make_tensor_descriptor")
+HAS_WARP_SPECIALIZE = supports_ws() and HAS_TENSOR_DESC
 
 
 # TmaAutoTuneHelper used in htyu's PR #5622
```
```diff
@@ -197,17 +204,18 @@ def matmul(a, b):
 
 @triton.autotune(
     configs=matmul_get_configs(),
-    key=["M", "N", "K"],
+    key=["M", "N", "K", "WARP_SPECIALIZE"],
 )
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_tma_ws(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
-                         M, N, K,  #
-                         BLOCK_SIZE_M: tl.constexpr,  #
-                         BLOCK_SIZE_N: tl.constexpr,  #
-                         BLOCK_SIZE_K: tl.constexpr,  #
-                         GROUP_SIZE_M: tl.constexpr,  #
-                         FP8_OUTPUT: tl.constexpr,  #
-                         ):
+def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
+                      M, N, K,  #
+                      BLOCK_SIZE_M: tl.constexpr,  #
+                      BLOCK_SIZE_N: tl.constexpr,  #
+                      BLOCK_SIZE_K: tl.constexpr,  #
+                      GROUP_SIZE_M: tl.constexpr,  #
+                      FP8_OUTPUT: tl.constexpr,  #
+                      WARP_SPECIALIZE: tl.constexpr,  #
+                      ):
     dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
 
     pid = tl.program_id(axis=0)
```
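Adding `WARP_SPECIALIZE` to the autotune `key` means each flag value gets its own tuning run and its own cached best config. A minimal sketch of that mechanism, with a made-up kernel (`scale_kernel` and `FLAG` are hypothetical names, not from the tutorial):

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({"BLOCK": 128}, num_warps=w) for w in (4, 8)],
    key=["N", "FLAG"],  # FLAG in the key => a separate best config per flag value
)
@triton.jit
def scale_kernel(x_ptr, N, FLAG: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    if FLAG:  # constexpr branch, resolved at compile time
        x = x * 2.0
    tl.store(x_ptr + offs, x, mask=mask)


x = torch.ones(1024, device="cuda")
grid = lambda META: (triton.cdiv(1024, META["BLOCK"]), )
scale_kernel[grid](x, 1024, FLAG=True)   # tuned and cached under FLAG=True
scale_kernel[grid](x, 1024, FLAG=False)  # re-tuned; cached separately
```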
```diff
@@ -227,7 +235,7 @@ def matmul_kernel_tma_ws(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
-    for k in tl.range(k_tiles, warp_specialize=True, num_stages=3):
+    for k in tl.range(k_tiles, warp_specialize=WARP_SPECIALIZE):
         offs_k = k * BLOCK_SIZE_K
         a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
         b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype)
```
```diff
@@ -240,7 +248,7 @@ def matmul_kernel_tma_ws(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
     tl._experimental_descriptor_store(c_desc_ptr, c, [offs_cm, offs_cn])
 
 
-def matmul_tma_ws(a, b):
+def matmul_tma(a, b, warp_specialize: bool):
     # Check constraints.
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     assert a.dtype == b.dtype, "Incompatible dtypes"
```
```diff
@@ -296,10 +304,11 @@ def grid(META):
     desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
     desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
 
-    matmul_kernel_tma_ws[grid](
+    matmul_kernel_tma[grid](
         desc_a, desc_b, desc_c,  #
         M, N, K,  #
         FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
+        WARP_SPECIALIZE=warp_specialize,  #
     )
     return c
 
```
```diff
@@ -402,19 +411,23 @@ def matmul_persistent(a, b):
 
 def matmul_tma_persistent_get_configs():
     return [
-        triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K" : BK, "GROUP_SIZE_M" : 8, "EPILOGUE_SUBTILE" : SUBTILE}, num_stages=s, num_warps=w) \
-        for BM in [128] \
-        for BN in [128, 256] \
-        for BK in [64, 128] \
-        for s in ([2, 3, 4]) \
-        for w in [4, 8] \
-        for SUBTILE in [True, False] \
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K": BK, "GROUP_SIZE_M": 8, "EPILOGUE_SUBTILE":
+                SUBTILE
+            }, num_stages=s, num_warps=w)  #
+        for BM in [128]  #
+        for BN in [128, 256]  #
+        for BK in [64, 128]  #
+        for s in ([2, 3, 4])  #
+        for w in [4, 8]  #
+        for SUBTILE in [True, False]  #
     ]
 
 
 @triton.autotune(
     configs=matmul_tma_persistent_get_configs(),
-    key=["M", "N", "K"],
+    key=["M", "N", "K", "WARP_SPECIALIZE"],
 )
 @triton.jit(launch_metadata=_matmul_launch_metadata)
 def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
```
```diff
@@ -425,7 +438,9 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
                                  GROUP_SIZE_M: tl.constexpr,  #
                                  FP8_OUTPUT: tl.constexpr,  #
                                  EPILOGUE_SUBTILE: tl.constexpr,  #
-                                 NUM_SMS: tl.constexpr):  #
+                                 NUM_SMS: tl.constexpr,  #
+                                 WARP_SPECIALIZE: tl.constexpr,  #
+                                 ):
     dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
     start_pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
```
```diff
@@ -439,7 +454,7 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
     # Enable warp specialization to leverage async warp scheduling in the GPU.
     # FIXME: This only works on Blackwell right now. On older GPUs, this will
     # use software pipelining.
-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True):
+    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
         offs_am = pid_m * BLOCK_SIZE_M
         offs_bn = pid_n * BLOCK_SIZE_N
```
```diff
@@ -473,7 +488,7 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
     tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am_c, offs_bn_c])
 
 
-def matmul_tma_persistent(a, b):
+def matmul_tma_persistent(a, b, warp_specialize: bool):
     # Check constraints.
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     assert a.dtype == b.dtype, "Incompatible dtypes"
```
```diff
@@ -542,13 +557,14 @@ def grid(META):
         M, N, K,  #
         FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
         NUM_SMS=NUM_SMS,  #
+        WARP_SPECIALIZE=warp_specialize,  #
     )
     return c
 
 
 @triton.autotune(
     configs=matmul_tma_persistent_get_configs(),
-    key=["M", "N", "K"],
+    key=["M", "N", "K", "WARP_SPECIALIZE"],
 )
 @triton.jit(launch_metadata=_matmul_launch_metadata)
 def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
```
```diff
@@ -558,7 +574,9 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
                                         BLOCK_SIZE_K: tl.constexpr,  #
                                         GROUP_SIZE_M: tl.constexpr,  #
                                         EPILOGUE_SUBTILE: tl.constexpr,  #
-                                        NUM_SMS: tl.constexpr):  #
+                                        NUM_SMS: tl.constexpr,  #
+                                        WARP_SPECIALIZE: tl.constexpr,  #
+                                        ):
     # Matmul using TMA and device-side descriptor creation
     dtype = c_ptr.dtype.element_ty
     start_pid = tl.program_id(axis=0)
```
```diff
@@ -591,7 +609,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
+    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
         offs_am = pid_m * BLOCK_SIZE_M
         offs_bn = pid_n * BLOCK_SIZE_N
```
```diff
@@ -621,7 +639,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
     c_desc.store([offs_cm, offs_cn], c)
 
 
-def matmul_descriptor_persistent(a, b):
+def matmul_descriptor_persistent(a, b, warp_specialize: bool):
     # Check constraints.
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     assert a.dtype == b.dtype, "Incompatible dtypes"
```
```diff
@@ -644,6 +662,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
         a, b, c,  #
         M, N, K,  #
         NUM_SMS=NUM_SMS,  #
+        WARP_SPECIALIZE=warp_specialize,  #
     )
     return c
 
```
```diff
@@ -683,12 +702,14 @@ def proton_context():
     proton.deactivate(0)
 
 
-def bench_fn(reps, warmup_reps, fn, *args):
+def bench_fn(label, reps, warmup_reps, fn, *args):
+    print(f"Benchmarking {label}: ...", end="")
     for _ in range(warmup_reps):
         fn(*args)
     with proton_context():
         for _ in range(reps):
             fn(*args)
+    print(f"\rBenchmarking {label}: done")
 
 
 def bench(K, dtype, reps=10000, warmup_reps=10000):
```
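The new progress output in `bench_fn` (and in `run_test` below) is the classic status-line pattern: print a pending marker without a newline, then overwrite the same terminal line with a carriage return once the work finishes. A tiny sketch of the pattern; the explicit flush is my addition for illustration, since output without a trailing newline may otherwise sit in the stdout buffer:

```python
import sys
import time

print("Benchmarking cublas: ...", end="")
sys.stdout.flush()  # make the pending marker visible before the long-running work
time.sleep(1.0)     # stand-in for the warmup + timed reps
print("\rBenchmarking cublas: done")  # "\r" rewrites the same terminal line
```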
```diff
@@ -700,60 +721,55 @@
     b = b.T.contiguous()
 
     if cublas is not None:
-        bench_fn(reps, warmup_reps, cublas_matmul, a, b)
+        bench_fn("cublas", reps, 1, cublas_matmul, a, b)
     if dtype == torch.float16:
-        bench_fn(reps, warmup_reps, torch_matmul, a, b)
-    bench_fn(reps, warmup_reps, matmul, a, b.T)
-    bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
-    if HAS_TMA_DESC:
-        bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
-    if HAS_TENSOR_DESC:
-        bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)
-        bench_fn(reps, warmup_reps, matmul_tma_ws, a, b)
+        bench_fn("torch", reps, warmup_reps, torch_matmul, a, b)
+    bench_fn("naive", reps, warmup_reps, matmul, a, b.T)
+    bench_fn("persistent", reps, warmup_reps, matmul_persistent, a, b.T)
+    warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False]
+    for ws in warp_specialize:
+        ws_str = "_ws" if ws else ""
+        if HAS_TMA_DESC:
+            bench_fn(f"tma_persistent{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma_persistent(a, b, ws), a, b)
+            bench_fn(f"tma{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma(a, b, ws), a, b)
+        if HAS_TENSOR_DESC:
+            bench_fn(f"descriptor_persistent{ws_str}", reps, warmup_reps,
+                     lambda a, b: matmul_descriptor_persistent(a, b, ws), a, b)
+
+
+def run_test(expect, fn, a, b, label, enabled=True):
+    print(f"  {label}: ...", end="")
+    if enabled:
+        actual = fn(a, b)
+        passed = torch.allclose(expect, actual.to(expect.dtype), atol=1.0)
+        icon = "✅" if passed else "❌"
+    else:
+        icon = "⭕"
+    print(f"\r  {label}: {icon}  ")
 
 
 def validate(M, N, K, dtype):
+    print(f"{M=}, {N=}, {K=}, verification naive vs: ")
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
 
-    torch_result = torch_matmul(a, b) if dtype == torch.float16 else None
-    cublas_result = cublas_matmul(a, b) if cublas is not None else None
-    naive_result = matmul(a, b.T)
-    tma_ws_result = matmul_tma_ws(a, b) if HAS_TENSOR_DESC else None
-    persistent_result = matmul_persistent(a, b.T)
-    tma_persistent_result = matmul_tma_persistent(a, b) if HAS_TMA_DESC else None
-    descriptor_persistent_result = matmul_descriptor_persistent(a, b) if HAS_TENSOR_DESC else None
-
-    if tma_ws_result is not None:
-        naive_vs_tma_ws = "✅" if torch.allclose(naive_result.to(torch.float16), tma_ws_result.to(torch.float16),
-                                                atol=1.0) else "❌"
-    if torch_result is not None:
-        naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
-                                               atol=1.0) else "❌"
-    if cublas_result is not None:
-        naive_vs_cublas = "✅" if torch.allclose(naive_result.to(torch.float16), cublas_result.to(torch.float16),
-                                                atol=1.0) else "❌"
-    naive_vs_persistent = "✅" if torch.allclose(naive_result.to(torch.float16), persistent_result.to(torch.float16),
-                                                atol=1.0) else "❌"
-    if tma_persistent_result is not None:
-        naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16),
-                                                        tma_persistent_result.to(torch.float16), atol=1.0) else "❌"
-    if descriptor_persistent_result is not None:
-        naive_vs_descriptor_persistent = "✅" if torch.allclose(cublas_result.to(
-            torch.float16), descriptor_persistent_result.to(torch.float16), atol=1.0) else "❌"
-    print(f"M={M}, N={N}, K={K} verification naive vs: ", end="")
-    if tma_ws_result is not None:
-        print(f"tma: {naive_vs_tma_ws} ", end="")
-    if torch_result is not None:
-        print(f"torch: {naive_vs_torch} ", end="")
-    if cublas_result is not None:
-        print(f"cublas: {naive_vs_cublas} ", end="")
-    print(f"persistent: {naive_vs_persistent} ", end="")
-    if tma_persistent_result is not None:
-        print(f"TMA persistent: {naive_vs_tma_persistent} ", end="")
-    if descriptor_persistent_result is not None:
-        print(f"Tensor descriptor persistent: {naive_vs_descriptor_persistent} ", end="")
+    naive_result = matmul(a, b.T).to(torch.float16)
+    run_test(naive_result, torch_matmul, a, b, "Torch", enabled=dtype == torch.float16)
+    run_test(naive_result, cublas_matmul, a, b, "cuBLAS", enabled=cublas is not None)
+    run_test(naive_result, matmul_persistent, a, b.T, "Persistent")
+
+    kernels = [
+        (matmul_tma, "TMA", HAS_TMA_DESC),
+        (matmul_tma_persistent, "TMA Persistent", HAS_TMA_DESC),
+        (matmul_descriptor_persistent, "Tensor Descriptor Persistent", HAS_TENSOR_DESC),
+    ]
+    warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False]
+
+    for (kernel, label, enabled), warp_specialize in itertools.product(kernels, warp_specialize):
+        label = f"{label} (warp_specialize={warp_specialize})"
+        enabled = enabled and (not warp_specialize or HAS_TENSOR_DESC)
+        run_test(naive_result, lambda a, b: kernel(a, b, warp_specialize), a, b, label, enabled)
     print()
 
 
```
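For reference, here is a hypothetical driver sketch showing how the reworked wrappers are called after this change. It assumes the tutorial's module-level names (`matmul_tma_persistent`, `HAS_WARP_SPECIALIZE`) are in scope; the shapes are arbitrary:

```python
import torch

a = torch.randn((512, 512), device="cuda", dtype=torch.float16)
b = torch.randn((512, 512), device="cuda", dtype=torch.float16)
b = b.T.contiguous()  # the tutorial keeps B transposed for the TMA paths

for ws in ([False, True] if HAS_WARP_SPECIALIZE else [False]):
    c = matmul_tma_persistent(a, b, ws)  # warp_specialize is now an explicit argument
```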