Commit 7ce3fa9

[CI][benchmarks] Fixed warmup type for flash attention and gemm-preop-exp (#5344)
This is a follow-up to #5293. That PR did not change the warmup type, so we currently do too much warmup and CI is too slow.
1 parent c795abd commit 7ce3fa9
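
For context: the `do_bench` helper used by these benchmarks supports two warmup strategies, selected by its `time_warmup` flag, and the fix below simply stops forcing `time_warmup=False` so the call sites pick up the default intended by #5293. The sketch below illustrates how such a switch typically works; it is an assumption, not the repo's actual implementation, and the names `do_bench_sketch`, `n_warmup`, `warmup_time_ms`, and `rep` are hypothetical.

import time
import statistics

def do_bench_sketch(fn, n_warmup=10, warmup_time_ms=100, rep=100,
                    grad_to_none=None, time_warmup=True):
    # Hypothetical stand-in for the repo's do_bench; the signature and
    # defaults are assumptions based only on how it is called in this diff.
    if time_warmup:
        # Time-based warmup: run fn until a wall-clock budget is spent.
        deadline = time.perf_counter() + warmup_time_ms / 1000.0
        while time.perf_counter() < deadline:
            fn()
    else:
        # Iteration-based warmup: run fn a fixed number of times, which
        # can cost far more wall-clock time for slow kernels.
        for _ in range(n_warmup):
            fn()

    times_ms = []
    for _ in range(rep):
        if grad_to_none is not None:
            # Reset gradients so repeated backward passes don't accumulate.
            for tensor in grad_to_none:
                tensor.grad = None
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1000.0)

    mean = statistics.mean(times_ms)
    cv = statistics.stdev(times_ms) / mean if rep > 1 else 0.0
    # Five return values, matching the unpacking at the call sites below.
    return times_ms, min(times_ms), max(times_ms), mean, cv

Per the commit message, leaving `time_warmup=False` behind forced the more expensive warmup path for these kernels, so dropping the argument is enough to restore the intended CI runtime.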

File tree: 2 files changed, +4 −4 lines changed

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 3 additions & 3 deletions
@@ -614,7 +614,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
                                          err_msg=f'Error comparing {name} between triton and torch')
             triton_fn = lambda: triton_o.backward(dout, retain_graph=True)
 
-        _, min_ms, max_ms, mean, cv = do_bench(triton_fn, grad_to_none=(q, k, v), time_warmup=False)
+        _, min_ms, max_ms, mean, cv = do_bench(triton_fn, grad_to_none=(q, k, v))
 
     elif provider == 'xetla':
         if MODE == 'bwd':
@@ -644,7 +644,7 @@ def xetla_bwd_fn():
                               bias_strideN, bias_strideF, attn_mask_padding)
                 return out
 
-            _, min_ms, max_ms, mean, cv = do_bench(xetla_bwd_fn, time_warmup=False)
+            _, min_ms, max_ms, mean, cv = do_bench(xetla_bwd_fn)
 
         else:
             min_ms = float('nan')
@@ -664,7 +664,7 @@ def cutlass_fwd_fn():
 
         benchmark_suite.assert_close(cutlass_fwd_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='cutlass to torch')
 
-        _, min_ms, max_ms, mean, cv = do_bench(cutlass_fwd_fn, time_warmup=False)
+        _, min_ms, max_ms, mean, cv = do_bench(cutlass_fwd_fn)
 
     else:
         min_ms = float('nan')

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suite.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, time_warmup=False)
+        _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
