Skip to content
2 changes: 2 additions & 0 deletions benchmarks/triton_kernels_benchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .benchmark_testing import (
assert_close,
do_bench,
do_prewarmup,
filter_providers,
perf_report,
Benchmark,
Expand All @@ -19,6 +20,7 @@
__all__ = [
"assert_close",
"do_bench",
"do_prewarmup",
"filter_providers",
"perf_report",
"Benchmark",
Expand Down
23 changes: 18 additions & 5 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
BENCHMARKING_CONFIG = {
"verify": os.getenv("VERIFY", "1") == "1",
"do_prewarmup": os.getenv("PREWARMUP", "1") == "1",
}


Expand All @@ -41,6 +42,19 @@ def synchronize():
torch.xpu.synchronize()


def do_prewarmup(fn, min_seconds=5):
    """Repeatedly invoke *fn* for at least *min_seconds* to absorb one-time compilation cost.

    Some benchmarked functions need a long first-run warm-up (e.g. JIT/compile work);
    this keeps calling ``fn`` (with a device ``synchronize`` after each call) until the
    minimum wall-clock budget is spent. Controlled by ``BENCHMARKING_CONFIG["do_prewarmup"]``
    and self-disabling, so it runs at most once per process.
    """
    if not BENCHMARKING_CONFIG["do_prewarmup"]:
        return

    # Spend at least `min_seconds` of wall-clock time calling fn.
    deadline = time.time() + min_seconds
    while time.time() < deadline:
        fn()
        synchronize()
    # Disable for subsequent calls — the warm-up only needs to happen once.
    BENCHMARKING_CONFIG["do_prewarmup"] = False


def _summarize_statistics(times, quantiles, return_mode):
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
Expand Down Expand Up @@ -127,18 +141,17 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no

assert return_mode in ["min", "max", "mean", "median"]

fn()
synchronize()
# Warm-up
for _ in range(n_warmup + 1):
fn()
synchronize()

# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
cache_size = 256 * 1024 * 1024
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)

# Warm-up
for _ in range(n_warmup):
fn()
# Benchmark
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
for _ in range(n_repeat):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,13 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)

benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
device=DEVICE)
# Need more warm-up iterations on B580 because of torch.compile overhead

is_bmg = any(name in torch.xpu.get_device_name().lower() for name in ('b570', 'b580'))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can keep it simple and increase across platforms, no need to check if it is bmg.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

if is_bmg:
benchmark_suit.do_prewarmup(triton_fn)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=200 if is_bmg else 10, n_repeat=10,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need both prewarmup and increasing n_warmup?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, first warmup across all shapes takes a lot of time. Just setting n_warmup to 200 is not always enough

quantiles=quantiles, device=DEVICE)

elif provider == 'onednn':
# OneDNN only supports MHA.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
triton_o = triton_fn()
triton_do = torch.randn_like(triton_o)
triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=5, n_repeat=5, quantiles=quantiles)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=5, quantiles=quantiles)
# Value checking cannot be implemented for these cases because:
# "The operator 'aten::_scaled_dot_product_flash_attention_for_cpu' is not currently implemented for the XPU device"

Expand Down
Loading