@@ -7,6 +7,7 @@
 from dataclasses import dataclass, field
 import itertools
 import functools
+import json
 
 import argparse
 import datetime
@@ -18,6 +19,7 @@
 import matplotlib.pyplot as plt
 
 import torch
+import triton.profiler as proton
 from torch.profiler import profile, ProfilerActivity, record_function
 
 from triton.testing import assert_close as triton_assert_close, Benchmark, do_bench as triton_do_bench
@@ -210,10 +212,101 @@ def extract_kernels(funcs):
     return _summarize_statistics(times, quantiles, return_mode)
 
 
+def do_bench_proton(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
+                    sync_submitting=True, time_warmup=True, benchmark_label=None, max_iters=1500):
+    """
+    Benchmark the runtime of the provided function using the Proton profiler. By default, return the mean runtime
+    of :code:`fn`; if :code:`quantiles` is provided, return those performance percentiles instead.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param n_warmup: Warmup budget: milliseconds of warmup if :code:`time_warmup` is True, otherwise iteration count
+    :type n_warmup: int
+    :param n_repeat: Number of repetitions to collect measurements
+    :type n_repeat: int
+    :param grad_to_none: Reset the gradients of the provided tensors to None
+    :type grad_to_none: torch.tensor, optional
+    :param quantiles: Performance percentiles to return instead of the mean.
+    :type quantiles: list[float]
+    """
+
+    assert return_mode in ["min", "max", "mean", "median"]
+
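+    # Initial call so that JIT compilation happens before any timing starts.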
+    fn()
+    synchronize()
+
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2
+    # doesn't contain any input data before the run
+    cache_size = 256 * 1024 * 1024
+    cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
+
+    # Warm-up
+    if time_warmup:
+        # Stop at whichever comes first: the iteration cap or the warmup time budget
+        warmup_time_s = n_warmup / 1000
+        assert sync_submitting
+        start = time.perf_counter()
+        i = 0
+        while i < max_iters and time.perf_counter() - start < warmup_time_s:
+            fn()
+            synchronize()
+            i += 1
+        print(f"Stopped warmup after {i} iterations")
+    else:
+        for _ in range(n_warmup):
+            fn()
+            # To be consistent with the benchmark measurements
+            if sync_submitting:
+                synchronize()
+
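+    # Start a Proton profiling session; with default arguments the finalized
+    # profile is written to ./proton.hatchet, which is read back below.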
+    proton.start()
+    # Benchmark
+    for idx in range(n_repeat):
+        # we don't want `fn` to accumulate gradient values
+        # if it contains a backward pass. So we clear the
+        # provided gradients
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+        # we clear the L2 cache before each run
+        cache.zero_()
+        if sync_submitting:
+            synchronize()
+        # record time of `fn`
+        with proton.scope(f"__profile_kernel_of_func{idx}"):
+            fn()
+    # Wait for all submitted work to finish before closing the profile
+    synchronize()
+    proton.finalize()
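+    # finalize() flushes the session to disk as a Hatchet-compatible JSON call tree.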
+    with open("./proton.hatchet", encoding="utf-8") as f:
+        data = json.load(f)
+
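+    # Keep only the top-level frames produced by our proton.scope() labels.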
+    label = "__profile_kernel_of_func" if benchmark_label is None else benchmark_label
+    functions = [x for x in data[0]["children"] if x["frame"]["name"].startswith(label)]
+
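+    # Each profiled scope is expected to contain exactly one kernel launch,
+    # so the scope's single child node carries that kernel's metrics.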
+    def extract_kernels(funcs):
+        return [x["children"][0]["metrics"] for x in funcs]
+
+    kernels = extract_kernels(functions)
+    # Convert the measured times from nanoseconds to milliseconds.
+    times = torch.tensor([ks["time (ns)"] * 1e-6 for ks in kernels], dtype=torch.float)
+    return _summarize_statistics(times, quantiles, return_mode)
+
+
 if BENCHMARKING_METHOD == "ELAPSED_TIME":
     do_bench = do_bench_elapsed_time
 elif BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     do_bench = do_bench_upstream_pytorch_profiler
+elif BENCHMARKING_METHOD == "PROTON_PROFILER":
+    do_bench = do_bench_proton
 else:
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 