Commit 3714e9b

[CI] Better warmup for flex attention on B580 (#4906)
Flex attention requires more warmup steps on B580. This PR:

1. Adds a pre-warmup step for flex attention that is called once per run, so it only runs for the first shape config; experiments show that the first config requires more warmup.
2. Makes GPU synchronization consistent between warmup and benchmarking.
3. Increases the number of warmup iterations.

Should resolve #4852. A better warmup strategy should be worked out after the research in #4911.
1 parent 28f406e commit 3714e9b
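
The prewarmup-once idea described in point 1 boils down to the pattern below (a minimal illustrative sketch, not the commit's code; the names prewarm_once, _already_prewarmed and the synchronize callback are placeholders):

    import time

    _already_prewarmed = False  # module-level flag: only the first shape config pays the cost

    def prewarm_once(fn, synchronize, min_seconds=5):
        """Call fn repeatedly for at least min_seconds so JIT compilation finishes
        before any timed iteration; later calls are no-ops."""
        global _already_prewarmed
        if _already_prewarmed:
            return
        start = time.time()
        while time.time() - start < min_seconds:
            fn()
            synchronize()  # wait for the GPU, matching how the timed runs are measured
        _already_prewarmed = True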

4 files changed: +27 −2 lines changed


benchmarks/triton_kernels_benchmark/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,7 @@
 from .benchmark_testing import (
     assert_close,
     do_bench,
+    do_prewarmup,
     filter_providers,
     perf_report,
     Benchmark,
@@ -19,6 +20,7 @@
 __all__ = [
     "assert_close",
     "do_bench",
+    "do_prewarmup",
     "filter_providers",
     "perf_report",
     "Benchmark",

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 18 additions & 0 deletions

@@ -27,6 +27,7 @@
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
 BENCHMARKING_CONFIG = {
     "verify": os.getenv("VERIFY", "1") == "1",
+    "do_prewarmup": os.getenv("PREWARMUP", "1") == "1",
 }


@@ -41,6 +42,19 @@ def synchronize():
     torch.xpu.synchronize()


+def do_prewarmup(fn, min_seconds=5):
+    """Looks like some functions require pre-warmup with minimum time to do the compilation.
+    It has to be done once."""
+    if not BENCHMARKING_CONFIG["do_prewarmup"]:
+        return
+
+    start = time.time()
+    while time.time() - start < min_seconds:
+        fn()
+        synchronize()
+    BENCHMARKING_CONFIG["do_prewarmup"] = False
+
+
 def _summarize_statistics(times, quantiles, return_mode):
     if quantiles is not None:
         ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
@@ -139,6 +153,10 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
     # Warm-up
     for _ in range(n_warmup):
         fn()
+        # To be consistent with the benchmark measurements
+        if sync_submitting:
+            synchronize()
+
     # Benchmark
     with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
         for _ in range(n_repeat):
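
Because BENCHMARKING_CONFIG is built at import time from the PREWARMUP environment variable (default "1"), the extra warmup can be switched off when iterating locally. A small sketch, under the assumption that the package is imported the way the benchmark scripts alias it (benchmark_suit):

    import os

    # Must be set before benchmark_testing is imported: the config dict reads
    # the variable once, at module import time.
    os.environ["PREWARMUP"] = "0"

    import triton_kernels_benchmark as benchmark_suit  # import name assumed from the benchmark scripts

    # With PREWARMUP=0, do_prewarmup() returns immediately and only the regular
    # n_warmup iterations inside do_bench() are executed.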

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 4 additions & 1 deletion

@@ -165,7 +165,10 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)

         benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
+
+        # Needs more warmup on B580 for some reason
+        benchmark_suit.do_prewarmup(triton_fn)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=200, n_repeat=10, quantiles=quantiles,
                                                               device=DEVICE)

     elif provider == 'onednn':

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_custom_masks.py

Lines changed: 3 additions & 1 deletion

@@ -112,7 +112,9 @@ def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=5, n_repeat=5, quantiles=quantiles)
+        # Needs more warmup on B580 for some reason
+        benchmark_suit.do_prewarmup(triton_fn)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=5, quantiles=quantiles)
         # Values checking cannot be implemented for these case as :
         # "The operator 'aten::_scaled_dot_product_flash_attention_for_cpu' is not currently implemented for the XPU device"

