
Commit bd13369

Implement benchmarking using Proton (#5385)
This pull request adds a new profiling mode, `PROTON_PROFILER`, as well as an optional run of the `fused_softmax` benchmark with this mode as a sanity check (the collected data is not recorded, only printed to stdout). After further testing, the mode can be enabled for all benchmarks and made the default benchmarking method.

---------

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 84d3f38 commit bd13369

File tree

3 files changed: +99 -1 lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,7 @@ on:
         options:
           - ELAPSED_TIME
           - UPSTREAM_PYTORCH_PROFILER
+          - PROTON_PROFILER
         default: UPSTREAM_PYTORCH_PROFILER
       verify:
         description: Verify the benchmark results
@@ -141,6 +142,14 @@ jobs:
           python build_report.py $REPORTS/softmax-performance.csv $REPORTS/softmax-xetla-report.csv --benchmark softmax --compiler xetla --param_cols "N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
           python build_report.py $REPORTS/softmax-performance.csv $REPORTS/softmax-onednn-report.csv --benchmark softmax --compiler onednn --param_cols "N" --tflops_col oneDNN-TFlops --hbm_col "oneDNN-GB/s" --tag $TAG

+      - name: Run Triton Softmax kernel benchmark with Proton
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'fused_softmax.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'fused_softmax.py') }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          BENCHMARKING_METHOD=PROTON_PROFILER python fused_softmax.py
+          source ../../scripts/capture-hw-details.sh
+
       - name: Run Triton GEMM kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py') }}
         run: |
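For context, a minimal sketch of how the BENCHMARKING_METHOD value set by this step is presumably picked up on the Python side; the exact lookup (and its default) lives in benchmark_testing.py and is not part of this diff, so treat it as an assumption:

    import os

    # Assumed environment lookup; the if/elif dispatch on this value is the part
    # that this commit extends in benchmark_testing.py further below.
    BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")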

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 88 additions & 0 deletions
@@ -7,6 +7,7 @@
 from dataclasses import dataclass, field
 import itertools
 import functools
+import json

 import argparse
 import datetime
@@ -18,6 +19,7 @@
 import matplotlib.pyplot as plt

 import torch
+import triton.profiler as proton
 from torch.profiler import profile, ProfilerActivity, record_function

 from triton.testing import assert_close as triton_assert_close, Benchmark, do_bench as triton_do_bench
@@ -210,10 +212,96 @@ def extract_kernels(funcs):
     return _summarize_statistics(times, quantiles, return_mode)


+def do_bench_proton(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
+                    sync_submitting=True, time_warmup=True, benchmark_label=None, max_iters=1500):
+    """
+    Benchmark the runtime of the provided function. By default, return the mean runtime of :code:`fn` along with
+    the 20-th and 80-th performance percentile.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param n_warmup: Number of repetitions for warmup
+    :type n_warmup: int
+    :param n_repeat: Number of repetitions to collect measurements
+    :type n_repeat: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param quantiles: Performance percentiles to return in addition to the mean.
+    :type quantiles: list[float]
+    """
+
+    assert return_mode in ["min", "max", "mean", "median"]
+
+    fn()
+    synchronize()
+
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2
+    # doesn't contain any input data before the run
+    cache_size = 256 * 1024 * 1024
+    cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
+
+    # Warm-up
+    if time_warmup:
+        # Stop either on max iteration number or max time
+        warmup_time_s = n_warmup / 1000
+        assert sync_submitting
+        start = time.perf_counter()
+        i = 0
+        while i < max_iters and time.perf_counter() - start < warmup_time_s:
+            fn()
+            synchronize()
+            i += 1
+        print(f"Stopped warmup after {i} iterations")
+    else:
+        for _ in range(n_warmup):
+            fn()
+            # To be consistent with the benchmark measurements
+            if sync_submitting:
+                synchronize()
+
+    proton.start()
+    # Benchmark
+    for idx in range(n_repeat):
+        # we don't want `fn` to accumulate gradient values
+        # if it contains a backward pass. So we clear the
+        # provided gradients
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+        # we clear the L2 cache before each run
+        cache.zero_()
+        if sync_submitting:
+            synchronize()
+        # record time of `fn`
+        with proton.scope(f"__profile_kernel_of_func{idx}"):
+            fn()
+        # Record clocks
+        synchronize()
+    proton.finalize()
+    with open("./proton.hatchet", encoding="utf-8") as f:
+        data = json.load(f)
+
+    profiling_func_filter = filter(
+        lambda x: x["frame"]["name"].startswith("__profile_kernel_of_func"
+                                                if benchmark_label is None else benchmark_label), data[0]["children"])
+    functions = list(profiling_func_filter)
+
+    def extract_kernels(funcs):
+        return [x["children"][0]["metrics"] for x in funcs]
+
+    kernels = extract_kernels(functions)
+    # Convert the time from nanoseconds to milliseconds.
+    times = torch.tensor([ks["time (ns)"] * 1e-6 for ks in kernels], dtype=torch.float)
+    return _summarize_statistics(times, quantiles, return_mode)
+
+
 if BENCHMARKING_METHOD == "ELAPSED_TIME":
     do_bench = do_bench_elapsed_time
 elif BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     do_bench = do_bench_upstream_pytorch_profiler
+elif BENCHMARKING_METHOD == "PROTON_PROFILER":
+    do_bench = do_bench_proton
 else:
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")

third_party/proton/proton/profile.py

Lines changed: 2 additions & 1 deletion
@@ -83,12 +83,13 @@ def start(
             Available options are ["tree", "trace"].
             Defaults to "tree".
         backend (str, optional): The backend to use for profiling.
-            Available options are [None, "cupti", "roctracer", "instrumentation"].
+            Available options are [None, "cupti", "xpupti", "roctracer", "instrumentation"].
             Defaults to None, which automatically selects the backend matching the current active runtime.
         mode (Union[str, BaseMode], optional): The "mode" to use for profiling, which is specific to the backend.
             Can be a string or an instance of BaseMode (or any subclass thereof).
             Defaults to None.
             For "cupti", available options are [None, "pcsampling"].
+            For "xpupti", available options are [None].
             For "roctracer", available options are [None].
             For "instrumentation", available options are [None].
             Each mode has a set of control knobs following with the mode name.
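For completeness, a minimal sketch of selecting the newly documented "xpupti" backend explicitly; proton.start/scope/finalize are used the same way in do_bench_proton above, while the toy workload is an assumption for the example:

    import torch
    import triton.profiler as proton

    x = torch.randn(1024, 1024, device="xpu")  # assumes an XPU device is available

    # Explicit backend selection; backend=None (the default) auto-detects the active runtime.
    proton.start("softmax_profile", backend="xpupti")
    with proton.scope("fused_softmax"):
        torch.softmax(x, dim=-1)
    proton.finalize()  # flush the profile; do_bench_proton parses the resulting .hatchet file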
