Commit e67ac5d

[BENCHMARKS] fix typo: benchmark_suit -> benchmark_suite (#5178)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent: d3dfcfe

8 files changed: +60 lines, -60 lines
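
The rename itself is mechanical, so it can be reproduced or audited with a short script. Below is a minimal sketch, assuming only the directory layout visible in this commit ('benchmarks/' and 'scripts/'); the script is illustrative and not part of the commit. A word-boundary match is enough to tell the typo apart from the corrected name, because 'benchmark_suit' is a prefix of 'benchmark_suite'.

    import re
    from pathlib import Path

    # \b after 'suit' cannot match inside 'benchmark_suite' ('t' is followed
    # by 'e', both word characters), so already-correct files stay untouched.
    PATTERN = re.compile(r'\bbenchmark_suit\b')

    def rename_alias(roots=('benchmarks', 'scripts')):
        for root in roots:
            for path in Path(root).rglob('*.py'):
                text = path.read_text()
                fixed = PATTERN.sub('benchmark_suite', text)
                if fixed != text:
                    path.write_text(fixed)
                    print(f'updated {path}')

    if __name__ == '__main__':
        rename_alias()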

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 9 additions & 9 deletions
@@ -13,7 +13,7 @@
 import torch._inductor.kernel.flex.flex_attention as flex_attn
 from torch._inductor.template_heuristics.triton import FlexConfig, FlexDecodeConfig
 
-import triton_kernels_benchmark as benchmark_suit
+import triton_kernels_benchmark as benchmark_suite
 import triton
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
@@ -77,8 +77,8 @@ def causal_mask(_, __, q_idx, kv_idx):
 
 # Kernel profiling for Backward mode is not working as expected:
 # For details: https://github.com/pytorch/pytorch/issues/144778
-@benchmark_suit.perf_report(
-    benchmark_suit.Benchmark(
+@benchmark_suite.perf_report(
+    benchmark_suite.Benchmark(
         x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
         x_vals=
         # Multi-head attention. H_q equals H_kv
@@ -158,8 +158,8 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
             mean = float('nan')
             cv = float('nan')
         else:
-            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=n_warmup, n_repeat=10,
-                                                                  quantiles=quantiles, device=DEVICE)
+            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(torch_fn, n_warmup=n_warmup, n_repeat=10,
+                                                                   quantiles=quantiles, device=DEVICE)
 
     elif provider == 'triton':
         kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True, 'USE_TMA': True}
@@ -176,14 +176,14 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
 
         tensor_names = ['out', 'grad_query', 'grad_key', 'grad_value']
         for eager, compiled, name in zip(eager_tensors, compiled_tensors, tensor_names):
-            benchmark_suit.assert_close(lambda: eager, lambda: compiled, atol=1e-2, rtol=1e-3, # pylint: disable=cell-var-from-loop
-                                        err_msg=f'Error comparing {name} between triton and torch')
+            benchmark_suite.assert_close(lambda: eager, lambda: compiled, atol=1e-2, rtol=1e-3, # pylint: disable=cell-var-from-loop
+                                         err_msg=f'Error comparing {name} between triton and torch')
 
         triton_fn = lambda: torch.autograd.grad((triton_o, ), (q, k, v), backwards_grad, retain_graph=True)
     else:
-        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
+        benchmark_suite.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
 
-    _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
+    _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
         triton_fn, n_warmup=n_warmup, n_repeat=10, quantiles=quantiles, device=DEVICE, grad_to_none=(q, k, v),
         benchmark_label=None if MODE == 'fwd' else 'CompiledFunctionBackward')

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_custom_masks.py

Lines changed: 7 additions & 7 deletions
@@ -10,7 +10,7 @@
 import torch
 import torch.nn.functional as F
 
-import triton_kernels_benchmark as benchmark_suit
+import triton_kernels_benchmark as benchmark_suite
 
 torch._dynamo.config.recompile_limit = 100 # pylint: disable=protected-access
 
@@ -57,8 +57,8 @@ def alibi_functional(score, _, h, q_idx, kv_idx):
 
 # Kernel profiling for Backward mode is not working as expected:
 # For details: https://github.com/pytorch/pytorch/issues/144778
-@benchmark_suit.perf_report(
-    benchmark_suit.Benchmark(
+@benchmark_suite.perf_report(
+    benchmark_suite.Benchmark(
         x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'MASK', 'MODE'],
         x_vals=[[z, h, 16384 // z, dhead, mask, mode]
                 for z in [4, 8, 16, 32]
@@ -114,8 +114,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
         triton_o = triton_fn()
         triton_do = torch.randn_like(triton_o)
         triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
-                                                              quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
+                                                               quantiles=quantiles)
     # Values checking cannot be implemented for these case as :
     # "The operator 'aten::_scaled_dot_product_flash_attention_for_cpu' is not currently implemented for the XPU device"
 
@@ -125,8 +125,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
         xformers_o = xformers_fn()
         xformers_do = torch.randn_like(xformers_o)
         xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=n_warmup, n_repeat=10,
-                                                              quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(xformers_fn, n_warmup=n_warmup, n_repeat=10,
+                                                               quantiles=quantiles)
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 8 additions & 8 deletions
@@ -11,7 +11,7 @@
 import triton
 import triton.language as tl
 
-import triton_kernels_benchmark as benchmark_suit
+import triton_kernels_benchmark as benchmark_suite
 import psutil
 
 INT8_ONLY_OPTION = os.getenv('INT8_ONLY', '0') == '1'
@@ -295,8 +295,8 @@ def is_enough_memory(x_val):
 
 
 # Benchmark Performance
-@benchmark_suit.perf_report(
-    benchmark_suit.Benchmark(
+@benchmark_suite.perf_report(
+    benchmark_suite.Benchmark(
         # argument names to use as an x-axis for the plot
         x_names=['B', 'M', 'K', 'N', 'dtype'],
         # different possible values for `x_name`
@@ -335,8 +335,8 @@ def benchmark(B, M, N, K, dtype, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b) + d, n_warmup=n_warmup,
-                                                                 n_repeat=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(lambda: torch.matmul(a, b) + d, n_warmup=n_warmup,
+                                                                  n_repeat=10, quantiles=quantiles)
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
@@ -355,9 +355,9 @@ def benchmark(B, M, N, K, dtype, provider):
         if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048],
                                                        [1, 512, 8192, 32768], [4, 32768, 4096, 128]]:
             # torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime
-            benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
-            _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
-                                                                     quantiles=quantiles)
+            benchmark_suite.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
+            _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
+                                                                      quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 6 additions & 6 deletions
@@ -12,7 +12,7 @@
 import triton
 import triton.language as tl
 
-import triton_kernels_benchmark as benchmark_suit
+import triton_kernels_benchmark as benchmark_suite
 
 kAlpha = tl.constexpr(math.sqrt(2.0 / math.pi))
 
@@ -253,8 +253,8 @@ def is_enough_memory(x_val):
 
 
 # Benchmark Performance
-@benchmark_suit.perf_report(
-    benchmark_suit.Benchmark(
+@benchmark_suite.perf_report(
+    benchmark_suite.Benchmark(
         # argument names to use as an x-axis for the plot
         x_names=['B', 'M', 'K', 'N'],
         # different possible values for `x_name`
@@ -294,9 +294,9 @@ def benchmark(B, M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
-                                                                 quantiles=quantiles)
+        benchmark_suite.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
+                                                                  quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 6 additions & 6 deletions
@@ -11,7 +11,7 @@
 import triton
 import triton.language as tl
 
-import triton_kernels_benchmark as benchmark_suit
+import triton_kernels_benchmark as benchmark_suite
 
 
 @triton.autotune(
@@ -241,8 +241,8 @@ def is_enough_memory(x_val):
 
 
 # Benchmark Performance
-@benchmark_suit.perf_report(
-    benchmark_suit.Benchmark(
+@benchmark_suite.perf_report(
+    benchmark_suite.Benchmark(
         # argument names to use as an x-axis for the plot
         x_names=['B', 'M', 'K', 'N'],
         # different possible values for `x_name`
@@ -286,9 +286,9 @@ def benchmark(B, M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
-                                                                 quantiles=quantiles, time_warmup=False)
+        benchmark_suite.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
+                                                                  quantiles=quantiles, time_warmup=False)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 11 additions & 11 deletions
@@ -2,7 +2,7 @@
 import triton
 import triton.language as tl
 
-import triton_kernels_benchmark as benchmark_suit
+import triton_kernels_benchmark as benchmark_suite
 from triton_kernels_benchmark import xetla_kernel
 
 
@@ -117,8 +117,8 @@ def forward(ctx, a, b, c, acc_dtype=None):
 
 
 # Benchmark Performance
-@benchmark_suit.perf_report(
-    benchmark_suit.Benchmark(
+@benchmark_suite.perf_report(
+    benchmark_suite.Benchmark(
         # argument names to use as an x-axis for the plot
         x_names=['M', 'K', 'N'],
         x_vals=[
@@ -149,16 +149,16 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=n_warmup,
-                                                                 n_repeat=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(lambda: torch.matmul(a, b), n_warmup=n_warmup,
+                                                                  n_repeat=10, quantiles=quantiles)
     elif provider == 'triton':
         c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
-                                                                 quantiles=quantiles)
+        benchmark_suite.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
+                                                                  quantiles=quantiles)
     elif provider == 'xetla':
         c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
         acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
@@ -169,9 +169,9 @@ def benchmark(M, N, K, provider):
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 
-        # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=n_warmup, n_repeat=100,
-                                                                 quantiles=quantiles)
+        # benchmark_suite.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(xetla_fn, n_warmup=n_warmup, n_repeat=100,
+                                                                  quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 11 additions & 11 deletions
@@ -9,7 +9,7 @@
 import triton
 import triton.language as tl
 
-import triton_kernels_benchmark as benchmark_suit
+import triton_kernels_benchmark as benchmark_suite
 from triton_kernels_benchmark import xetla_kernel
 
 
@@ -243,8 +243,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
 
 
 # Benchmark Performance
-@benchmark_suit.perf_report(
-    benchmark_suit.Benchmark(
+@benchmark_suite.perf_report(
+    benchmark_suite.Benchmark(
         # argument names to use as an x-axis for the plot
         x_names=['M', 'K', 'N'],
         x_vals=[[3072, 4096, 3072]],
@@ -271,15 +271,15 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=n_warmup,
-                                                                 n_repeat=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(lambda: torch.matmul(a, b), n_warmup=n_warmup,
+                                                                  n_repeat=10, quantiles=quantiles)
     elif provider == 'triton':
         c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
-        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=1e-2, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
-                                                                 quantiles=quantiles)
+        benchmark_suite.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=1e-2, err_msg='triton to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
+                                                                  quantiles=quantiles)
     elif provider == 'xetla':
         c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
         acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
@@ -290,9 +290,9 @@ def benchmark(M, N, K, provider):
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 
-        # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=n_warmup, n_repeat=10,
-                                                                 quantiles=quantiles)
+        # benchmark_suite.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(xetla_fn, n_warmup=n_warmup, n_repeat=10,
+                                                                  quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

scripts/flash_attention.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import triton
 
 from triton_kernels_benchmark.flash_attention_benchmark import _attention, tune_attn_fwd
-import triton_kernels_benchmark as benchmark_suit
+import triton_kernels_benchmark as benchmark_suite
 
 
 def get_options():
@@ -75,7 +75,7 @@ def run(options):
     #torch.set_printoptions(profile="default") # reset
 
     atol = 1e-1 if options.N_CTX == 16384 else 1e-2
-    benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=atol, rtol=1e-3, err_msg='triton to torch')
+    benchmark_suite.assert_close(lambda: triton_o, lambda: torch_o, atol=atol, rtol=1e-3, err_msg='triton to torch')
 
     if options.backward:
         triton_o.backward(torch.randn_like(triton_o), retain_graph=True)
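
After a repo-wide rename like this, a quick check that no stale references remain is cheap insurance. This is again only a sketch, under the same assumptions as the script above (the 'benchmarks/' and 'scripts/' layout, word-boundary matching):

    import re
    from pathlib import Path

    # List every Python file that still mentions the misspelled alias.
    stale = [
        str(path)
        for root in ('benchmarks', 'scripts')
        for path in Path(root).rglob('*.py')
        if re.search(r'\bbenchmark_suit\b', path.read_text())
    ]
    assert not stale, f'stale references remain: {stale}'
    print('rename complete: no stale benchmark_suit references')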
