
Commit 144c7dc

[TUTORIAL] Measure performance in the persistent-kernels tutorial in a stable thermal state (triton-lang#5042)
Following NVIDIA's recipe for measuring performance in the 09-persistent-matmul.py tutorial: get the system into a stable thermal state with a long warmup run, then do 1000 benchmark runs. We couldn't do this originally because creating and passing TMA descriptors created a GPU bubble that let the GPU cool down, so it never reached equilibrium and the TMA kernel results were skewed toward unfairly high scores. With the changes that pass descriptors via grid constants, I see results very close to the version with descriptor re-use, so we can now use this methodology and get correct benchmarking results. Example command line for measuring fp8 matmul performance across K=[512, 8192]: `python 09-persistent-matmul.py --prec fp8 --K_range 512 8192`
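In a nutshell, the new measurement flow is: a long unprofiled warmup until the GPU reaches thermal equilibrium, then profiled back-to-back runs. A minimal sketch of that pattern (the helper name and `proton` session calls mirror the diff below; any of the tutorial's matmul variants can be passed as `fn`):

```python
import triton.profiler as proton


def bench_fn(reps, warmup_reps, fn, *args):
    # Long unprofiled warmup: keeps the GPU busy until it settles into a
    # stable thermal state, so measured runs aren't flattered by a cool chip.
    for _ in range(warmup_reps):
        fn(*args)
    # Profile only the steady-state runs, launched back to back with no
    # sleeps in between (sleeps would let the GPU cool down again).
    proton.activate(0)
    try:
        for _ in range(reps):
            fn(*args)
    finally:
        proton.deactivate(0)
```

This replaces the previous pattern of `reps=10` runs interleaved with `time.sleep(0.01)`, which gave the GPU time to cool between kernel launches.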
1 parent a273986 commit 144c7dc

File tree

1 file changed (+27, -26 lines)


python/tutorials/09-persistent-matmul.py

Lines changed: 27 additions & 26 deletions
@@ -20,13 +20,13 @@
 """

 import argparse
-import time

 import torch
 import triton
 import triton.language as tl
 import triton.tools.experimental_descriptor
 import triton.profiler as proton
+from contextlib import contextmanager

 if torch.cuda.is_available():
     from triton._C.libtriton import nvidia
@@ -48,6 +48,8 @@ def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
     ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
+    if "tiles_per_update" in args:
+        ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]"
     if "c_ptr" in args:
         bytes_per_elem = args["c_ptr"].element_size()
     else:
@@ -541,41 +543,40 @@ def torch_matmul(a, b):
     return c


-def bench(K, dtype, tiles_per_update, reps=10):
+@contextmanager
+def proton_context():
+    proton.activate(0)
+    try:
+        yield
+    finally:
+        proton.deactivate(0)
+
+
+def bench_fn(reps, warmup_reps, fn, *args):
+    for _ in range(warmup_reps):
+        fn(*args)
+    with proton_context():
+        for _ in range(reps):
+            fn(*args)
+
+
+def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000):
     M = 8192
     N = 8192
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)

     b = b.T.contiguous()

-    proton.activate(0)
-
     if cublas is not None:
-        for _ in range(reps):
-            cublas_matmul(a, b)
-        time.sleep(0.01)
+        bench_fn(reps, warmup_reps, cublas_matmul, a, b)
     if dtype == torch.float16:
-        for _ in range(reps):
-            torch_matmul(a, b)
-        time.sleep(0.01)
-        for _ in range(reps):
-            matmul(a, b.T)
-        time.sleep(0.01)
-        for _ in range(reps):
-            matmul_persistent(a, b.T)
-        time.sleep(0.01)
+        bench_fn(reps, warmup_reps, torch_matmul, a, b)
+        bench_fn(reps, warmup_reps, matmul, a, b.T)
+        bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
     if supports_tma():
-        for _ in range(reps):
-            matmul_tma_persistent(a, b)
-        time.sleep(0.01)
-        with proton.scope(
-                f"matmul_kernel_device_tma_persistent [M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}]"):
-            for _ in range(reps):
-                matmul_device_tma_persistent(a, b, tiles_per_update)
-        time.sleep(0.01)
-
-    proton.deactivate(0)
+        bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
+        bench_fn(reps, warmup_reps, matmul_device_tma_persistent, a, b, tiles_per_update)


 def validate(M, N, K, dtype, tiles_per_update):
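For reference, the sweep driven by the example command line above boils down to roughly the following (a sketch, not the tutorial's actual argument parsing; the K step of 512, the `tiles_per_update` value, and the `proton.start` arguments are assumptions):

```python
import torch

import triton.profiler as proton

if __name__ == "__main__":
    # One profiling session for the whole sweep; bench_fn activates it
    # only around the measured steady-state runs.
    proton.start("matmul", hook="triton")  # assumed session setup
    proton.deactivate(0)
    dtype = torch.float8_e4m3fn  # --prec fp8
    for K in range(512, 8192 + 1, 512):  # --K_range 512 8192; step assumed
        bench(K, dtype, tiles_per_update=1)  # tiles_per_update value assumed
    proton.finalize()  # flush the collected profile
```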
