Commit e65dd81

[TUTORIAL] Multiple improvements to the tutorials, especially to 09-persistent-matmul.py (#4802)
- Format the introduction section in some tutorials.
- Add instructions for running the persistent matmul tutorial, as well as instructions for using `proton-viewer`.
- Replace `torch.zeros` with `torch.empty` to remove unnecessary GPU kernels.
- Add brackets `[` and `]` around shapes to improve the output formatting.
- Remove redundant metric accumulation, as the Triton hook already handles metric accumulation.
1 parent 184fb53 commit e65dd81

File tree: 6 files changed, +49 −28 lines

.github/workflows/documentation.yml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ jobs:
 
       - name: Install dependent packages
         run: |
-          sudo pip3 install tabulate cmake sphinx matplotlib myst_parser sphinx-rtd-theme pandas pytest sphinx-gallery sphinx-multiversion
+          sudo pip3 install tabulate cmake sphinx matplotlib myst_parser sphinx-rtd-theme pandas pytest sphinx-gallery sphinx-multiversion llnl-hatchet
 
       #- name: Fetch dependent branches
       #  run: |

docs/conf.py

Lines changed: 0 additions & 3 deletions
@@ -159,9 +159,6 @@ def documenter(app, obj, parent):
     'examples_dirs': '../python/tutorials/',
     'gallery_dirs': 'getting-started/tutorials',
     'filename_pattern': '',
-    # TODO: Re-enable the grouped-gemm tutorial. It currently hits this
-    # assertion:
-    # https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp#L127
     'ignore_pattern': r'(__init__\.py|11.*.py)',
     'within_subsection_order': FileNameSortKey,
     'reference_url': {

python/tutorials/06-fused-attention.py

Lines changed: 4 additions & 2 deletions
@@ -3,11 +3,13 @@
 ===============
 
 This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
+
 Credits: OpenAI kernel team
 
 Extra Credits:
-- Original flash attention paper (https://arxiv.org/abs/2205.14135)
-- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
+
+* Original flash attention paper (https://arxiv.org/abs/2205.14135)
+* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
 
 """
 

python/tutorials/07-extern-functions.py

Lines changed: 3 additions & 1 deletion
@@ -3,7 +3,9 @@
 ==============================
 Triton can invoke a custom function from an external library.
 In this example, we will use the `libdevice` library to apply `asin` on a tensor.
-Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html (CUDA) and/or https://github.com/ROCm/llvm-project/tree/amd-staging/amd/device-libs/ocml/src (HIP) regarding the semantics of all available libdevice functions.
+
+Please refer to `CUDA libdevice-users-guide <https://docs.nvidia.com/cuda/libdevice-users-guide/index.html>`_ and/or `HIP device-lib source code <https://github.com/ROCm/llvm-project/tree/amd-staging/amd/device-libs/ocml/src>`_ regarding the semantics of all available libdevice functions.
+
 In `libdevice.py`, we try to aggregate functions with the same computation but different data types together.
 For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
 Triton automatically selects the correct underlying device function to invoke based on input and output types.
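
To make the dtype-based dispatch concrete, here is a minimal sketch of a kernel calling `libdevice.asin`, modeled on the kernel in this tutorial; the tensor size and the 256-element block are illustrative:

    import torch
    import triton
    import triton.language as tl
    from triton.language.extra import libdevice


    @triton.jit
    def asin_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        # For float32 inputs this lowers to __nv_asinf; for float64, __nv_asin.
        y = libdevice.asin(x)
        tl.store(y_ptr + offsets, y, mask=mask)


    x = torch.rand(1024, device="cuda")
    y = torch.empty_like(x)
    asin_kernel[(triton.cdiv(x.numel(), 256), )](x, y, x.numel(), BLOCK_SIZE=256)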

python/tutorials/09-persistent-matmul.py

Lines changed: 38 additions & 17 deletions
@@ -1,10 +1,22 @@
 """
-Persistent FP8 Matmul
+Persistent Matmul
 =====================
 This script demonstrates persistent kernel implementations of matrix multiplication using Triton.
-It includes various matmul methods, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches, and only supports GPUs with compute capability >= 9.0.
-Triton and CuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler.
+Various matmul methods are included, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches.
+The kernels support both FP16 and FP8 data types but the FP8 implementation is only available on CUDA devices with compute capability >= 9.0.
+
+Triton and cuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler.
 Users can pass command-line arguments to specify matrix dimensions and iteration steps flexibly.
+
+.. code-block:: bash
+
+    # FP8
+    python 09-persistent-matmul.py --prec fp8 --K_range 128 1024 --K_step 128
+
+    # FP16
+    python 09-persistent-matmul.py --prec fp16 --K_range 128 1024 --K_step 128
+
+Note that currently this tutorial will fail on devices with a small shared memory size, such as RTX-4090.
 """
 
 import argparse
@@ -36,12 +48,12 @@ def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
     ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
-    ret["flops8"] = 2. * M * N * K
     if "c_ptr" in args:
         bytes_per_elem = args["c_ptr"].element_size()
     else:
         bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
-    ret["bytes"] = bytes_per_elem * (M * K + N * K)
+    ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K
+    ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
     return ret
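The generalized key replaces the hard-coded `flops8` so FP16 runs are labeled correctly; a minimal sketch of how it resolves, using the element sizes from the logic above:

    # bytes_per_elem is 1 for FP8 outputs and 2 for FP16 outputs, so the
    # recorded metric name matches the precision under test.
    for bytes_per_elem in (1, 2):
        print(f"flops{bytes_per_elem * 8}")  # prints: flops8, flops16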

@@ -328,7 +340,7 @@ def matmul_tma_persistent(a, b):
     N, K = b.shape
     dtype = a.dtype
 
-    c = torch.zeros((M, N), device=a.device, dtype=dtype)
+    c = torch.empty((M, N), device=a.device, dtype=dtype)
     desc_a = triton.tools.experimental_descriptor.create_2d_tma_descriptor(a.data_ptr(), M, K,
                                                                            configs[dtype]["BLOCK_SIZE_M"],
                                                                            configs[dtype]["BLOCK_SIZE_K"],
@@ -481,7 +493,7 @@ def matmul_device_tma_persistent(a, b, tiles_per_update):
     N, K = b.shape
     dtype = a.dtype
 
-    c = torch.zeros((M, N), device=a.device, dtype=dtype)
+    c = torch.empty((M, N), device=a.device, dtype=dtype)
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
     tma_size = 128
     workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda")
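The `zeros`-to-`empty` swap matters for profiling: `torch.zeros` enqueues a device-side fill kernel, while `torch.empty` only allocates, and the matmul kernels overwrite every element of `c` anyway. A minimal illustration (the shape is arbitrary):

    import torch

    # torch.zeros would launch a fill kernel on the GPU; torch.empty returns
    # uninitialized memory with no kernel launch, so the proton profile no
    # longer contains an unrelated fill kernel per iteration.
    c = torch.empty((512, 512), device="cuda", dtype=torch.float16)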
@@ -511,21 +523,20 @@ def cublas_matmul(a, b):
     dtype = a.dtype
     c = torch.empty((M, N), device=a.device, dtype=dtype)
     bytes_per_elem = a.element_size()
-    flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
-    with proton.scope(f"cublas M={M}, N={N}, K={K}",
-                      {"bytes": bytes_per_elem * (M * K + N * K), flops_str: 2. * M * N * K}):
+    flops_str = f"flops{bytes_per_elem * 8}"
+    with proton.scope(f"cublas [M={M}, N={N}, K={K}]",
+                      {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
         cublas.matmul(a, b, c)
     return c
 
 
 def torch_matmul(a, b):
     M, K = a.shape
     N, K = b.shape
-    dtype = a.dtype
     bytes_per_elem = a.element_size()
-    flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
-    with proton.scope(f"torch M={M}, N={N}, K={K}",
-                      {"bytes": bytes_per_elem * (M * K + N * K), flops_str: 2. * M * N * K}):
+    flops_str = f"flops{bytes_per_elem * 8}"
+    with proton.scope(f"torch [M={M}, N={N}, K={K}]",
+                      {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
         c = torch.matmul(a, b.T)
     return c

@@ -558,10 +569,8 @@ def bench(K, dtype, tiles_per_update, reps=10):
     for _ in range(reps):
         matmul_tma_persistent(a, b)
         time.sleep(0.01)
-    flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
     with proton.scope(
-            f"matmul_kernel_device_tma_persistent M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}",
-            {"bytes": a.element_size() * (M * K + N * K), flops_str: 2. * M * N * K}):
+            f"matmul_kernel_device_tma_persistent [M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}]"):
         for _ in range(reps):
             matmul_device_tma_persistent(a, b, tiles_per_update)
             time.sleep(0.01)
@@ -608,6 +617,17 @@ def validate(M, N, K, dtype, tiles_per_update):
     print()
 
 
+def show_profile(precision, profile_name):
+    import triton.profiler.viewer as proton_viewer
+    metrics = ["time/ms"]
+    if precision == 'fp8':
+        metrics = ["tflop8/s"] + metrics
+    elif precision == 'fp16':
+        metrics = ["tflop16/s"] + metrics
+    file_name = f"{profile_name}.hatchet"
+    proton_viewer.parse(metrics, file_name, depth=100)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-K", type=int, required=False, default=512)
@@ -642,3 +662,4 @@ def validate(M, N, K, dtype, tiles_per_update):
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
         bench(K, dtype, args.tiles_per_update)
     proton.finalize()
+    show_profile(args.prec, "matmul")
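
As a usage note: after `proton.finalize()` writes `matmul.hatchet`, `show_profile` prints the profile through the viewer API. The direct equivalent for an FP16 run, with the metric names exactly as `show_profile` builds them:

    import triton.profiler.viewer as proton_viewer

    # Same output as show_profile("fp16", "matmul"): the scope tree from
    # matmul.hatchet annotated with fp16 tflop/s and time in milliseconds.
    proton_viewer.parse(["tflop16/s", "time/ms"], "matmul.hatchet", depth=100)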

third_party/proton/proton/viewer.py

Lines changed: 3 additions & 4 deletions
@@ -180,7 +180,7 @@ def filter_frames(gf, include=None, exclude=None, threshold=None, metric=None):
     return gf
 
 
-def parse(metrics, filename, include, exclude, threshold, depth, format):
+def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=100, format=None):
     with open(filename, "r") as f:
         gf, raw_metrics, device_info = get_raw_metrics(f)
         gf = format_frames(gf, format)
@@ -190,10 +190,10 @@ def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=100, format=None):
         # TODO: generalize to support multiple metrics, not just the first one
         gf = filter_frames(gf, include, exclude, threshold, metrics[0])
         print(gf.tree(metric_column=metrics, expand_name=True, depth=depth, render_header=False))
-        emitWarnings(gf, metrics)
+        emit_warnings(gf, metrics)
 
 
-def emitWarnings(gf, metrics):
+def emit_warnings(gf, metrics):
     if "bytes (inc)" in metrics:
         byte_values = gf.dataframe["bytes (inc)"].values
         min_byte_value = np.nanmin(byte_values)
@@ -209,7 +209,6 @@ def show_metrics(file_name):
         for raw_metric in raw_metrics:
             raw_metric_no_unit = raw_metric.split("(")[0].strip().lower()
             print(f"- {raw_metric_no_unit}")
-        return
 
 
 def main():
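
Giving `parse` keyword defaults is what lets the tutorial's `show_profile` call it programmatically with only a few arguments; a minimal sketch (the profile path is illustrative):

    from triton.profiler.viewer import parse

    # include, exclude, threshold, and format default to None, and depth to
    # 100, so a basic invocation needs only the metrics and the profile file.
    parse(["time/ms"], "matmul.hatchet")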
