Commit f16b149

fixes

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent: a1fd0f9

File tree

8 files changed (+18, −18 lines)

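The change is mechanical: every `do_bench` call site switches from the old `warmup=`/`rep=` keywords to `n_warmup=`/`n_rep=`. A minimal sketch of an updated call site, assuming the same `benchmark_suit` import alias and five-value return used throughout the files below (the input tensor and shape are hypothetical):

    import torch
    import triton_kernels_benchmark as benchmark_suit  # import alias assumed from these files

    x = torch.randn(4096, 1024, device='xpu', dtype=torch.bfloat16)  # hypothetical input
    quantiles = [0.5, 0.0, 1.0]  # median, min, max -- as in the benchmarks below

    # New keyword names: n_warmup/n_rep read as iteration counts.
    _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
        lambda: torch.softmax(x, axis=-1),
        n_warmup=10, n_rep=10, quantiles=quantiles)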

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 3 additions & 3 deletions

@@ -242,7 +242,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
     if provider == 'onednn':
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
             lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
-            CAUSAL, scale=sm_scale), warmup=10, rep=10,
+            CAUSAL, scale=sm_scale), n_warmup=10, n_rep=10,
             quantiles=quantiles)
 
     elif provider == 'triton':
@@ -256,7 +256,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles,
                                                               kernel_name='_attn_fwd')
 
     elif provider == 'xetla':
@@ -272,7 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
 
         xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_rep=10, quantiles=quantiles,
                                                               kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')
 
     else:

benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 5 additions & 5 deletions

@@ -125,18 +125,18 @@ def benchmark(M, N, provider):
     quantiles = [0.5, 0.0, 1.0]
     if provider == "torch-native":
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
-                                                              warmup=10, rep=10)
+                                                              n_warmup=10, n_rep=10)
     if provider == "triton":
         out = torch.empty_like(x, device="xpu")
         triton_fn = lambda: softmax(x, out)
         torch_fn = lambda: torch.softmax(x, axis=-1)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_rep=10,
                                                               kernel_name="softmax_kernel")
 
     elif provider == "torch-jit":
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
-                                                              rep=10)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles,
+                                                              n_warmup=10, n_rep=10)
 
     elif provider == "xetla":
         name = f"softmax_shape_{M}_{N}"
@@ -154,7 +154,7 @@ def benchmark(M, N, provider):
             "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
             "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
         }
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_rep=10,
                                                               kernel_name=kernels_name[name])
 
     else:

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 3 additions & 3 deletions

@@ -283,7 +283,7 @@ def benchmark(B, M, N, K, provider):
         if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
             # Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method
             do_bench = do_bench_elapsed_time
-        _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10, rep=10,
+        _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_rep=10,
                                                   quantiles=quantiles, kernel_name='gemm_kernel')
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
@@ -296,7 +296,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles,
                                                                  kernel_name='matmul_kernel_with_block_pointers')
     elif provider == 'xetla':
         if B == 1:
@@ -340,7 +340,7 @@ def benchmark(B, M, N, K, provider):
         }
 
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_rep=10, quantiles=quantiles,
                                                                  kernel_name=kernels_name[name])
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
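In this file the onednn branch can swap the timing backend: when `BENCHMARKING_METHOD` is `PYTORCH_LEGACY_PROFILER_USING_IPEX`, the profiler-based helper is replaced by `do_bench_elapsed_time`, so both helpers must accept the renamed keywords. A minimal sketch of that dispatch, with the surrounding setup assumed from the file:

    # Hypothetical sketch: BENCHMARKING_METHOD, do_bench_elapsed_time and the
    # torch_a/torch_b operands come from the surrounding file.
    do_bench = benchmark_suit.do_bench
    if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
        # The legacy profiler over-reports onednn throughput, so fall back to
        # wall-clock timing; it must take the same n_warmup/n_rep keywords.
        do_bench = do_bench_elapsed_time
    _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_rep=10,
                                              quantiles=quantiles, kernel_name='gemm_kernel')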

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 1 addition & 1 deletion

@@ -275,7 +275,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles,
                                                                  kernel_name=kernel_name)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 1 addition & 1 deletion

@@ -277,7 +277,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles,
                                                                  kernel_name=kernel_name)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 1 addition & 1 deletion

@@ -265,7 +265,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles,
                                                                  kernel_name=kernel_name)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 2 additions & 2 deletions

@@ -148,15 +148,15 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_rep=10,
                                                               quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles,
                                                               kernel_name='_kernel')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 2 deletions

@@ -271,14 +271,14 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_rep=10,
                                                                  quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles,
                                                                  kernel_name=['first_wave', 'full_tiles'])
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
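For reference, a plausible shape for the helper all of these call sites target. This is an illustrative sketch only, assuming (as the rename suggests) that `n_warmup` and `n_rep` count iterations rather than the millisecond budgets used by upstream `triton.testing.do_bench`, and that the return order matches the unpacking above:

    import time

    import torch

    def do_bench(fn, n_warmup=10, n_rep=10, quantiles=(0.5, 0.0, 1.0), kernel_name=None):
        # Illustrative only: n_warmup/n_rep are iteration counts; kernel_name is
        # accepted for parity with the profiler-based variants and ignored here.
        # torch.xpu.synchronize() assumes an XPU-enabled PyTorch build.
        for _ in range(n_warmup):
            fn()
        torch.xpu.synchronize()
        times_ms = []
        for _ in range(n_rep):
            start = time.perf_counter()
            fn()
            torch.xpu.synchronize()
            times_ms.append((time.perf_counter() - start) * 1e3)
        t = torch.tensor(times_ms)
        mean_ms = t.mean().item()
        cv = t.std().item() / mean_ms  # coefficient of variation across timed runs
        med_ms, min_ms, max_ms = torch.quantile(t, torch.tensor(quantiles)).tolist()
        return med_ms, min_ms, max_ms, mean_ms, cv  # matches the unpacking above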
