
Commit 361dfa7

Authored by Anatoly Myachev
Don't use fast_flush=False as it seems to be deprecated (#2323)
Closes #2324

Note: PyTorch removed it as well: pytorch/pytorch#135387

CI:
* ~https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/11015650093~
* https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/11078059091

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 77d819c commit 361dfa7
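
The change is mechanical: drop the `fast_flush` keyword from each `benchmark_suit.do_bench` call and keep the remaining arguments unchanged. Below is a minimal before/after sketch, assuming the import alias and the `(fn, warmup, rep, quantiles)` call shape seen in the diffs; the workload, tensor shapes, and quantile values are illustrative placeholders, not code from this commit.

```python
import torch
import triton_kernels_benchmark as benchmark_suit  # assumed import alias for the suite's do_bench wrapper

# Placeholder workload on the XPU device
a = torch.randn(1024, 1024, device='xpu', dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device='xpu', dtype=torch.bfloat16)
fn = lambda: torch.matmul(a, b)

quantiles = [0.5, 0.0, 1.0]  # example quantiles; the benchmarks unpack a 5-tuple as below

# Before (now removed): passing the deprecated flag explicitly
# _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
#     fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)

# After: rely on do_bench's default cache-flushing behavior
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
    fn, warmup=10, rep=10, quantiles=quantiles)
```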

8 files changed: 14 additions, 23 deletions


benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -238,7 +238,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
             lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
             CAUSAL, scale=sm_scale), warmup=10, rep=10,
-            quantiles=quantiles, fast_flush=False)
+            quantiles=quantiles)

     elif provider == 'triton':
         # FIXME: remove below if condition when extend attention support for Causal = True done
@@ -257,8 +257,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
             ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)

     elif provider == 'xetla':
         module_name = f'flash_attn_causal_{CAUSAL}'.lower()
@@ -273,8 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

         xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
```

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -250,7 +250,7 @@ def benchmark(B, M, N, K, provider):

     if provider == 'onednn':
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                                 quantiles=quantiles, fast_flush=False)
+                                                                 quantiles=quantiles)
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
@@ -262,8 +262,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     elif provider == 'xetla':
         if B == 1:
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -278,8 +277,7 @@ def benchmark(B, M, N, K, provider):
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

```

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -273,8 +273,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

```

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -275,8 +275,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

```

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -263,8 +263,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

```

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -150,15 +150,14 @@ def benchmark(M, N, K, provider):

     if provider == 'onednn':
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                              quantiles=quantiles, fast_flush=False)
+                                                              quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

```

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -272,14 +272,13 @@ def benchmark(M, N, K, provider):

     if provider == 'onednn':
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                              quantiles=quantiles, fast_flush=False)
+                                                              quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

```

benchmarks/triton_kernels_benchmark/prefix_sums.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -44,7 +44,7 @@ def benchmark(M, N, AXIS, provider):

     if provider == 'triton':
         triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS)
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

```
