@@ -1,10 +1,22 @@
 """
-Persistent FP8 Matmul
+Persistent Matmul
 =====================
 This script demonstrates persistent kernel implementations of matrix multiplication using Triton.
-It includes various matmul methods, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches, and only supports GPUs with compute capability >= 9.0.
-Triton and CuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler.
+Various matmul methods are included, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches.
+The kernels support both FP16 and FP8 data types, but the FP8 implementation is only available on CUDA devices with compute capability >= 9.0.
+
+Triton and cuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler.
 Users can pass command-line arguments to specify matrix dimensions and iteration steps flexibly.
+
+.. code-block:: bash
+
+    # FP8
+    python 09-persistent-matmul.py --prec fp8 --K_range 128 1024 --K_step 128
+
+    # FP16
+    python 09-persistent-matmul.py --prec fp16 --K_range 128 1024 --K_step 128
+
+Note that this tutorial currently fails on devices with small shared memory, such as the RTX 4090.
 """
 
 import argparse
@@ -36,12 +48,12 @@ def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
     ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
-    ret["flops8"] = 2. * M * N * K
     if "c_ptr" in args:
         bytes_per_elem = args["c_ptr"].element_size()
     else:
         bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
-    ret["bytes"] = bytes_per_elem * (M * K + N * K)
+    ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K
+    ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
     return ret
 
 
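As a worked example of the updated accounting (illustrative numbers, not part of the patch): with an FP8 output, element_size() is 1, so the key becomes "flops8" (FP16 gives "flops16"), and the byte count now also charges the M * N output write.

.. code-block:: python

    # Sketch of the launch-metadata arithmetic for hypothetical FP8 shapes.
    M, N, K = 8192, 8192, 512
    bytes_per_elem = 1                                      # FP8 output; FP16 would be 2
    flops_key = f"flops{bytes_per_elem * 8}"                # -> "flops8"
    flops = 2. * M * N * K                                  # one multiply + one add per C element
    bytes_moved = bytes_per_elem * (M * K + N * K + M * N)  # read A, read B, write C
    print(flops_key, flops, bytes_moved)                    # flops8 68719476736.0 75497472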
@@ -328,7 +340,7 @@ def matmul_tma_persistent(a, b):
     N, K = b.shape
     dtype = a.dtype
 
-    c = torch.zeros((M, N), device=a.device, dtype=dtype)
+    c = torch.empty((M, N), device=a.device, dtype=dtype)
     desc_a = triton.tools.experimental_descriptor.create_2d_tma_descriptor(a.data_ptr(), M, K,
                                                                            configs[dtype]["BLOCK_SIZE_M"],
                                                                            configs[dtype]["BLOCK_SIZE_K"],
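Replacing torch.zeros with torch.empty is safe here because the kernel writes every element of c; torch.zeros would additionally launch a device-side fill kernel whose work is immediately overwritten. A minimal sketch of the difference:

.. code-block:: python

    import torch

    # torch.zeros allocates and launches a fill kernel; torch.empty only allocates.
    # Since the matmul overwrites all of c, the zero-fill is wasted work.
    c_zeros = torch.zeros((8192, 8192), device="cuda", dtype=torch.float16)
    c_empty = torch.empty((8192, 8192), device="cuda", dtype=torch.float16)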
@@ -481,7 +493,7 @@ def matmul_device_tma_persistent(a, b, tiles_per_update):
     N, K = b.shape
     dtype = a.dtype
 
-    c = torch.zeros((M, N), device=a.device, dtype=dtype)
+    c = torch.empty((M, N), device=a.device, dtype=dtype)
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
     tma_size = 128
     workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda")
@@ -511,21 +523,20 @@ def cublas_matmul(a, b):
     dtype = a.dtype
     c = torch.empty((M, N), device=a.device, dtype=dtype)
     bytes_per_elem = a.element_size()
-    flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
-    with proton.scope(f"cublas M={M}, N={N}, K={K}",
-                      {"bytes": bytes_per_elem * (M * K + N * K), flops_str: 2. * M * N * K}):
+    flops_str = f"flops{bytes_per_elem * 8}"
+    with proton.scope(f"cublas [M={M}, N={N}, K={K}]",
+                      {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
         cublas.matmul(a, b, c)
     return c
 
 
 def torch_matmul(a, b):
     M, K = a.shape
     N, K = b.shape
-    dtype = a.dtype
     bytes_per_elem = a.element_size()
-    flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
-    with proton.scope(f"torch M={M}, N={N}, K={K}",
-                      {"bytes": bytes_per_elem * (M * K + N * K), flops_str: 2. * M * N * K}):
+    flops_str = f"flops{bytes_per_elem * 8}"
+    with proton.scope(f"torch [M={M}, N={N}, K={K}]",
+                      {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
         c = torch.matmul(a, b.T)
     return c
 
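The scope pattern used by these wrappers can be reproduced standalone; the sketch below assumes the tutorial's import triton.profiler as proton and uses a plain torch matmul as the payload, with the same metric naming as the patch.

.. code-block:: python

    import torch
    import triton.profiler as proton  # assumed, matching the tutorial's imports

    def scoped_matmul(a, b):
        M, K = a.shape
        N, K = b.shape
        bytes_per_elem = a.element_size()
        flops_str = f"flops{bytes_per_elem * 8}"  # "flops16" for FP16, "flops8" for FP8
        # Attach byte and flop counts to this region so the profiler can derive throughput.
        with proton.scope(f"example [M={M}, N={N}, K={K}]",
                          {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
            return torch.matmul(a, b.T)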
@@ -558,10 +569,8 @@ def bench(K, dtype, tiles_per_update, reps=10):
         for _ in range(reps):
             matmul_tma_persistent(a, b)
         time.sleep(0.01)
-    flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
     with proton.scope(
-            f"matmul_kernel_device_tma_persistent M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}",
-            {"bytes": a.element_size() * (M * K + N * K), flops_str: 2. * M * N * K}):
+            f"matmul_kernel_device_tma_persistent [M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}]"):
         for _ in range(reps):
             matmul_device_tma_persistent(a, b, tiles_per_update)
         time.sleep(0.01)
@@ -608,6 +617,17 @@ def validate(M, N, K, dtype, tiles_per_update):
     print()
 
 
+def show_profile(precision, profile_name):
+    import triton.profiler.viewer as proton_viewer
+    metrics = ["time/ms"]
+    if precision == 'fp8':
+        metrics = ["tflop8/s"] + metrics
+    elif precision == 'fp16':
+        metrics = ["tflop16/s"] + metrics
+    file_name = f"{profile_name}.hatchet"
+    proton_viewer.parse(metrics, file_name, depth=100)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-K", type=int, required=False, default=512)
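The viewer metrics tflop8/s and tflop16/s are derived from the recorded flops8/flops16 counters and the measured time; roughly (a sketch under that assumption, with a hypothetical timing):

.. code-block:: python

    # Assumed convention: tflop8/s = flops8 / time, scaled to teraflops.
    flops8 = 2. * 8192 * 8192 * 512   # as recorded by the launch metadata
    time_ms = 5.0                     # hypothetical measured time for the region
    tflops = flops8 / (time_ms * 1e-3) / 1e12
    print(f"{tflops:.2f} tflop8/s")   # -> 13.74 tflop8/s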
@@ -642,3 +662,4 @@ def validate(M, N, K, dtype, tiles_per_update):
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
         bench(K, dtype, args.tiles_per_update)
     proton.finalize()
+    show_profile(args.prec, "matmul")