Commit d3a8eb0

Use iteration count instead of time for parameters warmup and rep of do_bench* functions for benchmarks (#2256)
Closes #2255. Partially reusing the changes that were removed in #2142 (namely, the part about using iteration count instead of time) solves the problem of not having enough data for the "CV" column.

---------

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 7fc0d2b commit d3a8eb0

9 files changed: +67 -68 lines changed
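For context, the old API took warmup and rep as time budgets in milliseconds and converted them into loop counts from a quick runtime estimate, so slow kernels could end up with only a handful of measured samples. A minimal sketch of the two parameterizations, with purely illustrative numbers (estimate_ms is hypothetical):

# Old behaviour: millisecond budgets divided by an estimated per-call cost.
estimate_ms = 4.0                                  # hypothetical per-call runtime
warmup, rep = 10, 10                               # budgets in ms
n_warmup_old = max(1, int(warmup / estimate_ms))   # -> 2
n_repeat_old = max(1, int(rep / estimate_ms))      # -> 2, too few samples for a stable "CV"

# New behaviour: the caller passes iteration counts directly.
n_warmup, n_repeat = 10, 10                        # always 10 warmup runs and 10 measured runs
print(n_warmup_old, n_repeat_old, n_warmup, n_repeat)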

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 43 additions & 46 deletions
@@ -36,18 +36,18 @@ def _summarize_statistics(times, quantiles, return_mode):
     return getattr(torch, return_mode)(times).item()
 
 
-def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
+def do_bench_ipex(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
                   sync_submitting=True, kernel_name=None):  # pylint: disable=unused-argument
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
 
     :param fn: Function to benchmark
     :type fn: Callable
-    :param warmup: Warmup time (in ms)
-    :type warmup: int
-    :param rep: Repetition time (in ms)
-    :type rep: int
+    :param n_warmup: Number of repetitions for warmup
+    :type n_warmup: int
+    :param n_repeat: Number of repetitions to collect measurements
+    :type n_repeat: int
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
@@ -69,20 +69,6 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, ret
     cache_size = 256 * 1024 * 1024
     cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
 
-    # Estimate the runtime of the function
-    start_event = torch.xpu.Event(enable_timing=True)
-    end_event = torch.xpu.Event(enable_timing=True)
-    start_event.record()
-    for _ in range(5):
-        cache.zero_()
-        fn()
-    end_event.record()
-    synchronize()
-    estimate_ms = start_event.elapsed_time(end_event) / 5
-
-    # compute number of warmup and repeat
-    n_warmup = max(1, int(warmup / estimate_ms))
-    n_repeat = max(1, int(rep / estimate_ms))
     # Warm-up
     for _ in range(n_warmup):
         fn()
@@ -121,18 +107,18 @@ def extract_kernels(funcs):
     return _summarize_statistics(times, quantiles, return_mode)
 
 
-def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
-                          kernel_name=None):  # pylint: disable=unused-argument
+def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean",
+                          device="xpu", kernel_name=None):  # pylint: disable=unused-argument
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
 
     :param fn: Function to benchmark
     :type fn: Callable
-    :param warmup: Warmup time (in ms)
-    :type warmup: int
-    :param rep: Repetition time (in ms)
-    :type rep: int
+    :param n_warmup: Number of repetitions for warmup
+    :type n_warmup: int
+    :param n_repeat: Number of repetitions to collect measurements
+    :type n_repeat: int
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
@@ -142,24 +128,49 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N
     import torch
     from triton.testing import do_bench as triton_do_bench
 
-    times = triton_do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, return_mode="all",
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2
+    # doesn't contain any input data before the run
+    cache_size = 256 * 1024 * 1024
+    cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
+
+    # Estimate the runtime of the function
+    start_event = torch.xpu.Event(enable_timing=True)
+    end_event = torch.xpu.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        cache.zero_()
+        fn()
+    end_event.record()
+    synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # The cache is also maintained in `triton_do_bench` function,
+    # there is no need to duplicate the amount of memory used.
+    del cache
+
+    # compute warmup and repeat times
+    warmup_time = n_warmup * estimate_ms
+    rep_time = n_repeat * estimate_ms
+
+    times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all",
                             device_type=device)
     times = torch.tensor(times, dtype=torch.float)
     return _summarize_statistics(times, quantiles, return_mode)
 
 
-def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean",
-                                       device="xpu", sync_submitting=True, kernel_name=None):
+def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None,
+                                       return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
 
     :param fn: Function to benchmark
     :type fn: Callable
-    :param warmup: Warmup time (in ms)
-    :type warmup: int
-    :param rep: Repetition time (in ms)
-    :type rep: int
+    :param n_warmup: Number of repetitions for warmup
+    :type n_warmup: int
+    :param n_repeat: Number of repetitions to collect measurements
+    :type n_repeat: int
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
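The new block in do_bench_elapsed_time above converts the counts back into time budgets, because triton's do_bench still expects warmup and rep in milliseconds. With illustrative numbers (the 0.8 ms estimate is hypothetical):

estimate_ms = 0.8                       # hypothetical result of the 5-call probe
n_warmup, n_repeat = 10, 10
warmup_time = n_warmup * estimate_ms    # 8.0 ms passed to triton_do_bench as `warmup`
rep_time = n_repeat * estimate_ms       # 8.0 ms passed to triton_do_bench as `rep`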
@@ -179,20 +190,6 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None
     cache_size = 256 * 1024 * 1024
     cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
 
-    # Estimate the runtime of the function
-    start_event = torch.xpu.Event(enable_timing=True)
-    end_event = torch.xpu.Event(enable_timing=True)
-    start_event.record()
-    for _ in range(5):
-        cache.zero_()
-        fn()
-    end_event.record()
-    synchronize()
-    estimate_ms = start_event.elapsed_time(end_event) / 5
-
-    # compute number of warmup and repeat
-    n_warmup = max(1, int(warmup / estimate_ms))
-    n_repeat = max(1, int(rep / estimate_ms))
     # Warm-up
     for _ in range(n_warmup):
         fn()
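With the runtime-estimation step removed, do_bench_ipex and do_bench_upstream_pytorch_profiler reduce to two fixed-length loops. A minimal sketch of that shape, assuming an XPU build of PyTorch and using plain XPU events for timing instead of the profiler machinery the real helpers use:

import torch

def do_bench_sketch(fn, n_warmup=25, n_repeat=100, device="xpu"):
    # L2-sized buffer cleared before each call so cached inputs do not skew timings.
    cache = torch.empty(256 * 1024 * 1024 // 4, dtype=torch.int, device=device)

    # Warm-up: a fixed number of unmeasured runs.
    for _ in range(n_warmup):
        fn()

    # Measurement: exactly n_repeat samples regardless of per-call runtime,
    # so statistics such as the "CV" column always have enough data points.
    times = []
    for _ in range(n_repeat):
        cache.zero_()
        start = torch.xpu.Event(enable_timing=True)
        end = torch.xpu.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.xpu.synchronize()
        times.append(start.elapsed_time(end))
    return torch.tensor(times, dtype=torch.float)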

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 3 additions & 3 deletions
@@ -242,7 +242,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
     if provider == 'onednn':
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
             lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
-            CAUSAL, scale=sm_scale), warmup=10, rep=10,
+            CAUSAL, scale=sm_scale), n_warmup=10, n_repeat=10,
             quantiles=quantiles)
 
     elif provider == 'triton':
@@ -256,7 +256,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
                                                               kernel_name='_attn_fwd')
 
     elif provider == 'xetla':
@@ -272,7 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
 
         xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
                                                               kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')
 
     else:

benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 5 additions & 5 deletions
@@ -125,18 +125,18 @@ def benchmark(M, N, provider):
     quantiles = [0.5, 0.0, 1.0]
     if provider == "torch-native":
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
-                                                              warmup=10, rep=10)
+                                                              n_warmup=10, n_repeat=10)
     if provider == "triton":
         out = torch.empty_like(x, device="xpu")
         triton_fn = lambda: softmax(x, out)
         torch_fn = lambda: torch.softmax(x, axis=-1)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
                                                               kernel_name="softmax_kernel")
 
     elif provider == "torch-jit":
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
-                                                              rep=10)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles,
+                                                              n_warmup=10, n_repeat=10)
 
     elif provider == "xetla":
         name = f"softmax_shape_{M}_{N}"
@@ -154,7 +154,7 @@ def benchmark(M, N, provider):
             "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
             "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
         }
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
                                                               kernel_name=kernels_name[name])
 
     else:

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 5 additions & 4 deletions
@@ -283,7 +283,7 @@ def benchmark(B, M, N, K, provider):
         if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
             # Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method
             do_bench = do_bench_elapsed_time
-        _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10, rep=10,
+        _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_repeat=10,
                                                   quantiles=quantiles, kernel_name='gemm_kernel')
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
@@ -296,7 +296,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles,
                                                                  kernel_name='matmul_kernel_with_block_pointers')
     elif provider == 'xetla':
         if B == 1:
@@ -340,8 +341,8 @@ def benchmark(B, M, N, K, provider):
         }
 
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 kernel_name=kernels_name[name])
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name=kernels_name[name])
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -275,8 +275,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 kernel_name=kernel_name)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name=kernel_name)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -277,8 +277,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 kernel_name=kernel_name)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name=kernel_name)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -265,8 +265,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 kernel_name=kernel_name)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name=kernel_name)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -148,15 +148,15 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                               quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
                                                               kernel_name='_kernel')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 3 additions & 2 deletions
@@ -271,14 +271,15 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles,
                                                                  kernel_name=['first_wave', 'full_tiles'])
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
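Across the benchmark scripts the call sites now follow the same pattern. A representative, hypothetical invocation, assuming the suite's usual import alias (import triton_kernels_benchmark as benchmark_suit) and an XPU device:

import torch
import triton_kernels_benchmark as benchmark_suit

a = torch.randn((4096, 4096), device='xpu', dtype=torch.bfloat16)
b = torch.randn((4096, 4096), device='xpu', dtype=torch.bfloat16)

quantiles = [0.5, 0.0, 1.0]
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
    lambda: torch.matmul(a, b),
    n_warmup=10, n_repeat=10,   # iteration counts, no longer time budgets in ms
    quantiles=quantiles)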
