diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index ceda9d977c..f675fc1c98 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -18,6 +18,7 @@ on:
         options:
           - ELAPSED_TIME
           - UPSTREAM_PYTORCH_PROFILER
+          - PROTON_PROFILER
         default: UPSTREAM_PYTORCH_PROFILER
       verify:
         description: Verify the benchmark results
@@ -141,6 +142,14 @@ jobs:
           python build_report.py $REPORTS/softmax-performance.csv $REPORTS/softmax-xetla-report.csv --benchmark softmax --compiler xetla --param_cols "N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
           python build_report.py $REPORTS/softmax-performance.csv $REPORTS/softmax-onednn-report.csv --benchmark softmax --compiler onednn --param_cols "N" --tflops_col oneDNN-TFlops --hbm_col "oneDNN-GB/s" --tag $TAG

+      - name: Run Triton Softmax kernel benchmark with Proton
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'fused_softmax.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'fused_softmax.py') }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          BENCHMARKING_METHOD=PROTON_PROFILER python fused_softmax.py
+          source ../../scripts/capture-hw-details.sh
+
       - name: Run Triton GEMM kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py') }}
         run: |
diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 0e7d441e12..3d3f1d154a 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -7,6 +7,7 @@
 from dataclasses import dataclass, field
 import itertools
 import functools
+import json
 import argparse
 import datetime
@@ -18,6 +19,7 @@
 import matplotlib.pyplot as plt
 import torch
+import triton.profiler as proton
 from torch.profiler import profile, ProfilerActivity, record_function
 from triton.testing import assert_close as triton_assert_close, Benchmark, do_bench as triton_do_bench
@@ -210,10 +212,96 @@ def extract_kernels(funcs):
     return _summarize_statistics(times, quantiles, return_mode)


+def do_bench_proton(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
+                    sync_submitting=True, time_warmup=True, benchmark_label=None, max_iters=1500):
+    """
+    Benchmark the runtime of the provided function using the Proton profiler. By default, return the mean runtime
+    of :code:`fn` in milliseconds; pass :code:`quantiles` or a different :code:`return_mode` to change what is reported.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param n_warmup: Number of warmup repetitions, or the warmup time budget in ms when :code:`time_warmup` is True
+    :type n_warmup: int
+    :param n_repeat: Number of repetitions to collect measurements
+    :type n_repeat: int
+    :param grad_to_none: Reset the gradient of the provided tensors to None
+    :type grad_to_none: torch.tensor, optional
+    :param quantiles: Performance percentiles to return instead of a single aggregated value
+    :type quantiles: list[float], optional
+    """
+
+    assert return_mode in ["min", "max", "mean", "median"]
+
+    fn()
+    synchronize()
+
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2 cache
+    # doesn't contain any input data before the run
+    cache_size = 256 * 1024 * 1024
+    cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
+
+    # Warm-up
+    if time_warmup:
+        # Stop either on max iteration number or max time
+        warmup_time_s = n_warmup / 1000
+        assert sync_submitting
+        start = time.perf_counter()
+        i = 0
+        while i < max_iters and time.perf_counter() - start < warmup_time_s:
+            fn()
+            synchronize()
+            i += 1
+        print(f"Stopped warmup after {i} iterations")
+    else:
+        for _ in range(n_warmup):
+            fn()
+        # To be consistent with the benchmark measurements
+        if sync_submitting:
+            synchronize()
+
+    proton.start()
+    # Benchmark
+    for idx in range(n_repeat):
+        # we don't want `fn` to accumulate gradient values
+        # if it contains a backward pass. So we clear the
+        # provided gradients
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+        # we clear the L2 cache before each run
+        cache.zero_()
+        if sync_submitting:
+            synchronize()
+        # record time of `fn` under a per-iteration Proton scope
+        with proton.scope(f"__profile_kernel_of_func{idx}"):
+            fn()
+    # Wait for all submitted kernels to finish before finalizing the profile
+    synchronize()
+    proton.finalize()
+    with open("./proton.hatchet", encoding="utf-8") as f:
+        data = json.load(f)
+
+    profiling_func_filter = filter(
+        lambda x: x["frame"]["name"].startswith("__profile_kernel_of_func"
+                                                if benchmark_label is None else benchmark_label), data[0]["children"])
+    functions = list(profiling_func_filter)
+
+    def extract_kernels(funcs):
+        return [x["children"][0]["metrics"] for x in funcs]
+
+    kernels = extract_kernels(functions)
+    # Convert the measured times from nanoseconds to milliseconds.
+    times = torch.tensor([ks["time (ns)"] * 1e-6 for ks in kernels], dtype=torch.float)
+    return _summarize_statistics(times, quantiles, return_mode)
+
+
 if BENCHMARKING_METHOD == "ELAPSED_TIME":
     do_bench = do_bench_elapsed_time
 elif BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     do_bench = do_bench_upstream_pytorch_profiler
+elif BENCHMARKING_METHOD == "PROTON_PROFILER":
+    do_bench = do_bench_proton
 else:
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py
index 3e1f54e4aa..c1085917f4 100644
--- a/third_party/proton/proton/profile.py
+++ b/third_party/proton/proton/profile.py
@@ -83,12 +83,13 @@ def start(
             Available options are ["tree", "trace"].
             Defaults to "tree".
         backend (str, optional): The backend to use for profiling.
-            Available options are [None, "cupti", "roctracer", "instrumentation"].
+            Available options are [None, "cupti", "xpupti", "roctracer", "instrumentation"].
             Defaults to None, which automatically selects the backend matching the current active runtime.
         mode (Union[str, BaseMode], optional): The "mode" to use for profiling, which is specific to the backend.
            Can be a string or an instance of BaseMode (or any subclass thereof).
            Defaults to None.
            For "cupti", available options are [None, "pcsampling"].
+           For "xpupti", available options are [None].
            For "roctracer", available options are [None].
            For "instrumentation", available options are [None].
            Each mode has a set of control knobs following with the mode name.
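
For reference, a minimal sketch of the Proton profiling flow that do_bench_proton drives, run standalone. The kernel callable run_softmax() is a hypothetical stand-in for the benchmarked fn, and the explicit backend="xpupti" argument is only illustrative: as the docstring above notes, leaving backend unset lets proton.start() auto-select the backend for the active runtime.

    # Minimal sketch, assuming an XPU runtime with the PTI-based "xpupti" backend available
    # and a hypothetical run_softmax() callable standing in for the benchmarked function.
    import json
    import triton.profiler as proton

    proton.start(backend="xpupti")  # or proton.start() to auto-select the backend
    with proton.scope("__profile_kernel_of_func0"):
        run_softmax()
    proton.finalize()

    # Proton writes its profile as a Hatchet JSON tree; kernel times are reported in ns,
    # which do_bench_proton converts to ms before summarizing.
    with open("./proton.hatchet", encoding="utf-8") as f:
        data = json.load(f)
    for child in data[0]["children"]:
        print(child["frame"]["name"], child["children"][0]["metrics"]["time (ns)"])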