diff --git a/benchmarks/triton_kernels_benchmark/__init__.py b/benchmarks/triton_kernels_benchmark/__init__.py
index 818936de21..27ed7ff9f5 100644
--- a/benchmarks/triton_kernels_benchmark/__init__.py
+++ b/benchmarks/triton_kernels_benchmark/__init__.py
@@ -3,6 +3,7 @@
 from .benchmark_testing import (
     assert_close,
     do_bench,
+    do_prewarmup,
     filter_providers,
     perf_report,
     Benchmark,
@@ -19,6 +20,7 @@
 __all__ = [
     "assert_close",
     "do_bench",
+    "do_prewarmup",
     "filter_providers",
     "perf_report",
     "Benchmark",
diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 1110ee1161..6e81af3158 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -27,6 +27,7 @@
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
 BENCHMARKING_CONFIG = {
     "verify": os.getenv("VERIFY", "1") == "1",
+    "do_prewarmup": os.getenv("PREWARMUP", "1") == "1",
 }
 
 
@@ -41,6 +42,19 @@ def synchronize():
     torch.xpu.synchronize()
 
 
+def do_prewarmup(fn, min_seconds=5):
+    """Some functions (e.g. those going through torch.compile) require a pre-warmup
+    of a minimum duration to finish compilation. This only has to be done once."""
+    if not BENCHMARKING_CONFIG["do_prewarmup"]:
+        return
+
+    start = time.time()
+    while time.time() - start < min_seconds:
+        fn()
+        synchronize()
+    BENCHMARKING_CONFIG["do_prewarmup"] = False
+
+
 def _summarize_statistics(times, quantiles, return_mode):
     if quantiles is not None:
         ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
@@ -127,8 +141,10 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
 
     assert return_mode in ["min", "max", "mean", "median"]
 
-    fn()
-    synchronize()
+    # Warm-up
+    for _ in range(n_warmup + 1):
+        fn()
+    synchronize()
 
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
@@ -136,9 +152,6 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
     cache_size = 256 * 1024 * 1024
     cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
 
-    # Warm-up
-    for _ in range(n_warmup):
-        fn()
     # Benchmark
     with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
         for _ in range(n_repeat):
diff --git a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
index 1728dd2dd6..462d47dadc 100644
--- a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
+++ b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
@@ -165,8 +165,13 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
 
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-                                                              device=DEVICE)
+        # Need more warm-up runs on B580 due to torch.compile
+
+        is_bmg = any(name in torch.xpu.get_device_name().lower() for name in ('b570', 'b580'))
+        if is_bmg:
+            benchmark_suit.do_prewarmup(triton_fn)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=200 if is_bmg else 10, n_repeat=10,
+                                                              quantiles=quantiles, device=DEVICE)
 
     elif provider == 'onednn':
         # OneDNN only supports MHA.
diff --git a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_custom_masks.py b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_custom_masks.py
index e055865265..11c7432dfa 100644
--- a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_custom_masks.py
+++ b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_custom_masks.py
@@ -112,7 +112,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
            triton_o = triton_fn()
            triton_do = torch.randn_like(triton_o)
            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=5, n_repeat=5, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=5, quantiles=quantiles)
 
         # Values checking cannot be implemented for these case as :
         # "The operator 'aten::_scaled_dot_product_flash_attention_for_cpu' is not currently implemented for the XPU device"
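Note on usage (not part of the patch): the sketch below shows how the new hook is meant to compose with do_bench, mirroring the causal-mask benchmark above. It assumes the package is imported as benchmark_suit, as in the benchmarks; the matmul workload is a hypothetical stand-in for a real benchmark lambda, and the quantiles value is an assumption chosen so do_bench returns the five values the benchmarks unpack.

import torch

import triton_kernels_benchmark as benchmark_suit

# Hypothetical workload standing in for a real benchmark lambda
# (e.g. a flex-attention backward pass).
def triton_fn():
    a = torch.randn(4096, 4096, device='xpu')
    b = torch.randn(4096, 4096, device='xpu')
    return a @ b

# Battlemage GPUs (B570/B580) need a time-based pre-warmup so torch.compile can
# finish compiling: do_prewarmup runs fn for at least min_seconds (default 5),
# then flips BENCHMARKING_CONFIG["do_prewarmup"] off so later calls are no-ops.
is_bmg = any(name in torch.xpu.get_device_name().lower() for name in ('b570', 'b580'))
if is_bmg:
    benchmark_suit.do_prewarmup(triton_fn)

# Raise n_warmup on Battlemage, as the causal-mask benchmark does.
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=200 if is_bmg else 10,
                                                      n_repeat=10, quantiles=[0.5, 0.0, 1.0])

The pre-warmup can also be disabled for a whole run by setting PREWARMUP=0 in the environment, since do_prewarmup keys off BENCHMARKING_CONFIG["do_prewarmup"].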