Commit d1e0731
[TUTORIAL] Remove grouped gemm simulation from 09-persistent-matmul (#5461)
As discussed in the [multi-buffering PR], the persistent matmul tutorial should be kept as an apples-to-apples performance comparison. In particular, the existing perf results make tensor descriptors look bad. With this updated tutorial I get results like (`K=4096, prec=fp8`):

```
├─ 1278.215 4731.062 cublas [M=8192, N=8192, K=4096]
│  └─ nan 4731.062 sm90_xmma_gemm_e4m3e4m3_e4m3f32_f32_tn_n_tilesize128x128x128_warpgroupsize1x1x1_bias_f16_execute_segment_k_off_kernel__5x_cublas
├─ 1208.855 454.774 matmul_kernel [M=8192, N=8192, K=4096]
├─ 1285.360 427.706 matmul_kernel_persistent [M=8192, N=8192, K=4096]
├─ 1330.667 413.143 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=4096]
└─ 1347.254 408.057 matmul_kernel_tma_persistent [M=8192, N=8192, K=4096]
```

So on H100 the tensor-descriptor kernel gives a 3.5% flops uplift over the plain persistent matmul, vs. 4.8% for host-side TMA. For the same shapes with fp16 I see a 13% uplift from tensor descriptors vs. 13.4% from host-side TMA.

[multi-buffering PR]: triton-lang/triton#5290 (comment)
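As a quick sanity check on the quoted percentages, here is a small sketch deriving them from the fp8 numbers above; it assumes the first proton column is the flops figure the uplifts refer to, and the variable names are purely illustrative:

```python
# Sketch: derive the quoted uplift percentages from the proton output above.
persistent = 1285.360  # matmul_kernel_persistent (assumed: flops column)
descriptor = 1330.667  # matmul_kernel_descriptor_persistent
host_tma = 1347.254    # matmul_kernel_tma_persistent

print(f"tensor descriptor vs. persistent: {descriptor / persistent - 1:.1%}")  # ~3.5%
print(f"host-side TMA vs. persistent:     {host_tma / persistent - 1:.1%}")    # ~4.8%
```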
1 parent 3c058ee commit d1e0731

File tree: 1 file changed (+10, −49 lines)

python/tutorials/09-persistent-matmul.py

Lines changed: 10 additions & 49 deletions
```diff
@@ -50,8 +50,6 @@ def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
     ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
-    if "tiles_per_update" in args:
-        ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]"
     if "c_ptr" in args:
         bytes_per_elem = args["c_ptr"].element_size()
     else:
@@ -376,8 +374,7 @@ def matmul_tma_persistent(a, b):
 
 
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
-                                        a_ptr, b_ptr, c_ptr, #
+def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
                                         M, N, K, #
                                         BLOCK_SIZE_M: tl.constexpr, #
                                         BLOCK_SIZE_N: tl.constexpr, #
@@ -417,7 +414,6 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
 
     tile_id = start_pid - NUM_SMS
     ki = -1
-    ni = -1
 
     pid_m = 0
     pid_n = 0
@@ -427,36 +423,10 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    # Create an opaque value to prevent the descriptor creation from being
-    # hoisted out of the loop
-    zero = tl.inline_asm_elementwise("mov.b32 $0, 0;", "=r", [], dtype=tl.int32, is_pure=True, pack=1)
 
     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
         if ki == 0:
-            ni += 1
-
-            # Simulate a grouped gemm
-            if ni == tiles_per_update:
-                a_desc = tl._experimental_make_tensor_descriptor(
-                    a_ptr + zero,
-                    shape=[M, K],
-                    strides=[K, 1],
-                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
-                )
-                b_desc = tl._experimental_make_tensor_descriptor(
-                    b_ptr + zero,
-                    shape=[N, K],
-                    strides=[K, 1],
-                    block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
-                )
-                c_desc = tl._experimental_make_tensor_descriptor(
-                    c_ptr + zero,
-                    shape=[M, N],
-                    strides=[N, 1],
-                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-                )
-                ni = 0
 
             tile_id += NUM_SMS
             group_id = tile_id // num_pid_in_group
@@ -482,8 +452,7 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
             accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
 
-def matmul_descriptor_persistent(a, b, tiles_per_update):
-    # Autotuner does not work with TMA. Use manual config.
+def matmul_descriptor_persistent(a, b):
     configs = {
         torch.float8_e4m3fn: {
             "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
@@ -513,7 +482,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
 
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
     matmul_kernel_descriptor_persistent[grid](
-        tiles_per_update, #
         a, b, c, #
         M, N, K, #
         BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], #
@@ -570,7 +538,7 @@ def bench_fn(reps, warmup_reps, fn, *args):
         fn(*args)
 
 
-def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000):
+def bench(K, dtype, reps=1000, warmup_reps=10000):
     M = 8192
     N = 8192
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
@@ -586,10 +554,10 @@ def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000):
     bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
     if supports_tma():
         bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
-        bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b, tiles_per_update)
+        bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)
 
 
-def validate(M, N, K, dtype, tiles_per_update):
+def validate(M, N, K, dtype):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
@@ -599,7 +567,7 @@ def validate(M, N, K, dtype, tiles_per_update):
     naive_result = matmul(a, b.T)
     persistent_result = matmul_persistent(a, b.T)
     tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    descriptor_persistent_result = matmul_descriptor_persistent(a, b, tiles_per_update) if supports_tma() else None
+    descriptor_persistent_result = matmul_descriptor_persistent(a, b) if supports_tma() else None
 
     if torch_result is not None:
         naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -624,7 +592,7 @@ def validate(M, N, K, dtype, tiles_per_update):
     if tma_persistent_result is not None:
         print(f"TMA persistent: {naive_vs_tma_persistent} ", end="")
     if descriptor_persistent_result is not None:
-        print(f"Device TMA persistent: {naive_vs_descriptor_persistent} ", end="")
+        print(f"Tensor descriptor persistent: {naive_vs_descriptor_persistent} ", end="")
     print()
 
 
@@ -644,13 +612,6 @@ def show_profile(precision, profile_name):
     parser.add_argument("-K", type=int, required=False, default=512)
     parser.add_argument("--K_range", type=int, nargs=2)
     parser.add_argument("--K_step", type=int, default=512)
-    parser.add_argument(
-        "--tiles_per_update",
-        type=int,
-        default=1,
-        help=
-        "Number of output tiles calculated for each update of the tma descriptor in matmul_descriptor_persistent_kernel",
-    )
     parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
     args = parser.parse_args()
 
@@ -666,11 +627,11 @@ def show_profile(precision, profile_name):
 
     torch.manual_seed(0)
 
-    validate(32, 32, 32, dtype, args.tiles_per_update)
-    validate(8192, 8192, 512, dtype, args.tiles_per_update)
+    validate(32, 32, 32, dtype)
+    validate(8192, 8192, 512, dtype)
 
     proton.start("matmul", hook="triton")
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-        bench(K, dtype, args.tiles_per_update)
+        bench(K, dtype)
     proton.finalize()
     show_profile(args.prec, "matmul")
```
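For context on what the kernel looks like after this change: the tensor descriptors are now built once per program, before the persistent loop, instead of being rebuilt every `tiles_per_update` tiles. A minimal fragment sketch of that shape follows (not a runnable kernel on its own; the `tl._experimental_make_tensor_descriptor` arguments are taken from the removed hunk above, while the loop body is elided because this diff does not show it):

```python
# Fragment sketch, assuming the remaining tutorial code creates the
# descriptors once, outside the persistent loop.
a_desc = tl._experimental_make_tensor_descriptor(
    a_ptr, shape=[M, K], strides=[K, 1], block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K])
b_desc = tl._experimental_make_tensor_descriptor(
    b_ptr, shape=[N, K], strides=[K, 1], block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K])
c_desc = tl._experimental_make_tensor_descriptor(
    c_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N])

for _ in range(0, k_tiles * tiles_per_SM):
    # ... tile/K bookkeeping, loads through a_desc/b_desc, tl.dot,
    # and an epilogue store through c_desc, as in the rest of the tutorial ...
    pass
```

With the in-loop recreation gone, the `a_ptr + zero` opaque-value trick (whose only purpose, per the removed comment, was to stop the descriptor creation from being hoisted out of the loop) is no longer needed, which is why it is deleted as well.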
