2 changes: 2 additions & 0 deletions benchmarks/triton_kernels_benchmark/__init__.py
@@ -3,6 +3,7 @@
 from .benchmark_testing import (
     assert_close,
     do_bench,
+    do_prewarmup,
     filter_providers,
     perf_report,
     Benchmark,
@@ -19,6 +20,7 @@
 __all__ = [
     "assert_close",
     "do_bench",
+    "do_prewarmup",
     "filter_providers",
     "perf_report",
     "Benchmark",
18 changes: 18 additions & 0 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -27,6 +27,7 @@
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
 BENCHMARKING_CONFIG = {
     "verify": os.getenv("VERIFY", "1") == "1",
+    "do_prewarmup": os.getenv("PREWARMUP", "1") == "1",
 }


@@ -41,6 +42,19 @@ def synchronize():
     torch.xpu.synchronize()


+def do_prewarmup(fn, min_seconds=5):
+    """Some kernels appear to need a minimum amount of pre-warmup time to finish compilation,
+    so run `fn` repeatedly for at least `min_seconds`. This only has to be done once per process."""
+    if not BENCHMARKING_CONFIG["do_prewarmup"]:
+        return
+
+    start = time.time()
+    while time.time() - start < min_seconds:
+        fn()
+        synchronize()
+    BENCHMARKING_CONFIG["do_prewarmup"] = False
+
+
 def _summarize_statistics(times, quantiles, return_mode):
     if quantiles is not None:
         ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
@@ -139,6 +153,10 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
     # Warm-up
     for _ in range(n_warmup):
         fn()
+        # To be consistent with the benchmark measurements
+        if sync_submitting:
+            synchronize()
+
     # Benchmark
     with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
         for _ in range(n_repeat):
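Taken together, the new helper is meant to be called from the individual benchmark scripts right before `do_bench`, exactly as the attention-benchmark hunks below do. The following is a minimal usage sketch, not part of this diff: the `triton_kernels_benchmark` import alias, the matmul stand-in workload, and the `quantiles` value are assumptions for illustration, and setting `PREWARMUP=0` in the environment skips the extra warm-up entirely.

import torch
import triton_kernels_benchmark as benchmark_suit

# Stand-in workload: any zero-argument callable that launches the kernel under test.
a = torch.randn(4096, 4096, device='xpu')
b = torch.randn(4096, 4096, device='xpu')
triton_fn = lambda: torch.matmul(a, b)

# Assumed quantiles, following the usual [median, min, max] convention in these benchmarks.
quantiles = [0.5, 0.0, 1.0]

# Run the workload for at least ~5 s once per process so compilation is finished before
# timing starts (a no-op when PREWARMUP=0), then measure with do_bench as before.
benchmark_suit.do_prewarmup(triton_fn)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                      quantiles=quantiles)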
@@ -165,7 +165,10 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)

         benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
+
+        # Needs more warmup on B580 for some reason
+        benchmark_suit.do_prewarmup(triton_fn)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=200, n_repeat=10, quantiles=quantiles,
                                                               device=DEVICE)

     elif provider == 'onednn':
@@ -112,7 +112,9 @@ def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=5, n_repeat=5, quantiles=quantiles)
+        # Needs more warmup on B580 for some reason
+        benchmark_suit.do_prewarmup(triton_fn)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=5, quantiles=quantiles)
         # Values checking cannot be implemented for these case as :
         # "The operator 'aten::_scaled_dot_product_flash_attention_for_cpu' is not currently implemented for the XPU device"
