Commit a26848c

[Tutorial] Use per-SM descriptors in matmul tutorial (triton-lang#4682)
1 parent 09675e5 commit a26848c

File tree

.gitignore
python/tutorials/09-persistent-matmul.py

2 files changed: +71, -34 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -65,3 +65,6 @@ docs/sg_execution_times.rst
 
 # Vim
 *.swp
+
+# macOS
+.DS_Store

python/tutorials/09-persistent-matmul.py

Lines changed: 68 additions & 34 deletions
@@ -360,9 +360,9 @@ def matmul_tma_persistent(a, b):
 
 
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
+def matmul_kernel_device_tma_persistent(workspace_ptr, #
+                                        tiles_per_update: tl.constexpr, #
                                         a_ptr, b_ptr, c_ptr, #
-                                        ready_flag, #
                                         M, N, K, #
                                         BLOCK_SIZE_M: tl.constexpr, #
                                         BLOCK_SIZE_N: tl.constexpr, #
@@ -377,31 +377,32 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n
 
-    if start_pid == 0:
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
-                                                             element_ty=a_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
-                                                             load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
-                                                             element_ty=b_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
-                                                             element_ty=c_ptr.dtype.element_ty)
-        tl.atomic_xchg(ready_flag, 1, sem="release")
-    else:
-        flag = tl.full([], 0, tl.int32)
-        while flag != 1:
-            flag = tl.atomic_add(ready_flag, 0, sem="acquire")
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+    TMA_SIZE: tl.constexpr = 128
+    workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE
+    a_desc_ptr = workspace_base
+    b_desc_ptr = workspace_base + TMA_SIZE
+    c_desc_ptr = workspace_base + 2 * TMA_SIZE
+
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
+                                                         element_ty=a_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                         load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
+                                                         element_ty=b_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
+                                                         element_ty=c_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
 
     tiles_per_SM = num_tiles // NUM_SMS
     if start_pid < num_tiles % NUM_SMS:
         tiles_per_SM += 1
 
     tile_id = start_pid - NUM_SMS
     ki = -1
+    ni = -1
 
     pid_m = 0
     pid_n = 0
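
Note on this hunk: rather than one globally shared descriptor triple that program 0 builds while every other program spin-waits on a ready flag, each program now carves its own three 128-byte descriptor slots out of a flat byte workspace, so the release/acquire handshake disappears entirely. A minimal sketch of the addressing scheme in plain Python (the helper name descriptor_offsets is illustrative, not part of the tutorial):

TMA_SIZE = 128  # bytes per TMA descriptor, matching the kernel's constexpr

def descriptor_offsets(start_pid: int) -> tuple[int, int, int]:
    # Byte offsets of the A, B, and C descriptor slots owned by program start_pid.
    base = start_pid * 3 * TMA_SIZE
    return base, base + TMA_SIZE, base + 2 * TMA_SIZE

# Program 0 owns bytes [0, 384); program 1 owns [384, 768); and so on.
assert descriptor_offsets(0) == (0, 128, 256)
assert descriptor_offsets(1) == (384, 512, 640)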
@@ -415,6 +416,27 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
         if ki == 0:
+            ni += 1
+
+            # Simulate a grouped gemm
+            if ni == tiles_per_update:
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                                     load_size=[BLOCK_SIZE_M,
+                                                                                BLOCK_SIZE_K], global_size=[M, K],
+                                                                     element_ty=a_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                                     load_size=[BLOCK_SIZE_N,
+                                                                                BLOCK_SIZE_K], global_size=[N, K],
+                                                                     element_ty=b_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                                     load_size=[BLOCK_SIZE_M,
+                                                                                BLOCK_SIZE_N], global_size=[M, N],
+                                                                     element_ty=c_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+                ni = 0
+
             tile_id += NUM_SMS
             group_id = tile_id // num_pid_in_group
             first_pid_m = group_id * GROUP_SIZE_M
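
The ni counter added here rebuilds all three tensor maps once every tiles_per_update output tiles, standing in for a grouped GEMM where each group would need fresh descriptors. A small sketch of that cadence, with rebuild_descriptors as a hypothetical stand-in for the three create2d calls plus the fences:

def tile_loop(k_tiles, tiles_per_SM, tiles_per_update, rebuild_descriptors):
    # Mirrors the kernel's flattened loop: one iteration per (tile, k-step) pair.
    ki, ni = -1, -1
    for _ in range(k_tiles * tiles_per_SM):
        ki = 0 if ki == k_tiles - 1 else ki + 1
        if ki == 0:  # a new output tile begins
            ni += 1
            if ni == tiles_per_update:
                rebuild_descriptors()
                ni = 0

# With tiles_per_update=1, descriptors are rebuilt before every tile after the
# first; larger values amortize the update cost over more tiles.
tile_loop(k_tiles=4, tiles_per_SM=3, tiles_per_update=1, rebuild_descriptors=lambda: None)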
@@ -435,10 +457,11 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
             c = accumulator.to(dtype)
 
             tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
+
             accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
 
-def matmul_device_tma_persistent(a, b):
+def matmul_device_tma_persistent(a, b, tiles_per_update):
     # Autotuner does not work with TMA. Use manual config.
     configs = {
         torch.float8_e4m3fn: {
@@ -459,15 +482,15 @@ def matmul_device_tma_persistent(a, b):
     dtype = a.dtype
 
     c = torch.zeros((M, N), device=a.device, dtype=dtype)
-    a_desc, b_desc, c_desc = [torch.empty(128, dtype=torch.uint8, device="cuda") for _ in range(3)]
-    ready_flag = torch.zeros((), dtype=torch.int32, device="cuda")
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    tma_size = 128
+    workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda")
 
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
     matmul_kernel_device_tma_persistent[grid](
-        a_desc, b_desc, c_desc, #
+        workspace, #
+        tiles_per_update, #
         a, b, c, #
-        ready_flag, #
         M, N, K, #
         BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], #
         BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], #
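
Host side, the workspace replaces the three shared descriptor tensors and the ready flag with one buffer sized for every SM. A quick sizing check, assuming a hypothetical 132-SM device (an H100-class figure; the real code queries multi_processor_count):

def workspace_bytes(num_sms, tma_size=128, descs_per_sm=3):
    # Total size of the per-SM descriptor workspace, in bytes.
    return num_sms * descs_per_sm * tma_size

assert workspace_bytes(132) == 50688  # about 49.5 KiB for 132 SMs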
@@ -507,7 +530,7 @@ def torch_matmul(a, b):
     return c
 
 
-def bench(K, dtype, reps=10):
+def bench(K, dtype, tiles_per_update, reps=10):
     M = 8192
     N = 8192
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
@@ -535,14 +558,18 @@ def bench(K, dtype, reps=10):
         for _ in range(reps):
             matmul_tma_persistent(a, b)
             time.sleep(0.01)
-    for _ in range(reps):
-        matmul_device_tma_persistent(a, b)
-        time.sleep(0.01)
+    flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
+    with proton.scope(
+            f"matmul_kernel_device_tma_persistent M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}",
+        {"bytes": a.element_size() * (M * K + N * K), flops_str: 2. * M * N * K}):
+        for _ in range(reps):
+            matmul_device_tma_persistent(a, b, tiles_per_update)
+            time.sleep(0.01)
 
     proton.deactivate(0)
 
 
-def validate(M, N, K, dtype):
+def validate(M, N, K, dtype, tiles_per_update):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
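
The new proton.scope hands the profiler a byte count and a flop count, so the profile can report achieved bandwidth and throughput for the device-TMA variant just like the scopes around the other kernels. The arithmetic behind those two metrics, checked for the fp16 case at K=512 (values illustrative):

M = N = 8192
K = 512
elem_size = 2  # bytes per fp16 element
bytes_moved = elem_size * (M * K + N * K)  # A and B operand traffic, as in the scope
flops = 2. * M * N * K  # one multiply plus one add per output MAC

assert bytes_moved == 16_777_216  # 16 MiB
assert flops == 68_719_476_736  # ~68.7 GFLOP per call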
@@ -552,7 +579,7 @@ def validate(M, N, K, dtype):
     naive_result = matmul(a, b.T)
     persistent_result = matmul_persistent(a, b.T)
     tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    device_tma_persistent_result = matmul_device_tma_persistent(a, b) if supports_tma() else None
+    device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if supports_tma() else None
 
     if torch_result is not None:
         naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -586,6 +613,13 @@ def validate(M, N, K, dtype):
     parser.add_argument("-K", type=int, required=False, default=512)
     parser.add_argument("--K_range", type=int, nargs=2)
     parser.add_argument("--K_step", type=int, default=512)
+    parser.add_argument(
+        "--tiles_per_update",
+        type=int,
+        default=1,
+        help=
+        "Number of output tiles calculated for each update of the tma descriptor in matmul_device_tma_persistent_kernel",
+    )
     parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
     args = parser.parse_args()
 
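
With the flag registered, descriptor-update frequency can be swept from the command line; a hypothetical invocation (flag values chosen arbitrarily):

python 09-persistent-matmul.py --prec fp8 --K_range 512 8192 --K_step 512 --tiles_per_update 4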
@@ -601,10 +635,10 @@ def validate(M, N, K, dtype):
 
     torch.manual_seed(0)
 
-    validate(32, 32, 32, dtype)
-    validate(8192, 8192, 512, dtype)
+    validate(32, 32, 32, dtype, args.tiles_per_update)
+    validate(8192, 8192, 512, dtype, args.tiles_per_update)
 
     proton.start("matmul", hook="triton")
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-        bench(K, dtype)
+        bench(K, dtype, args.tiles_per_update)
     proton.finalize()
