Skip to content
Merged
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b1d2a0b
Increase 'warmup' and 'rep' for FA benchmark
anmyachev Sep 16, 2024
339b709
Merge branch 'main' into amyachev/bench-time
anmyachev Sep 16, 2024
5ebbd01
Use 150ms
anmyachev Sep 16, 2024
b1cc599
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 17, 2024
0ad146f
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 17, 2024
bbf0557
Merge branch 'main' into amyachev/bench-time
anmyachev Sep 19, 2024
81fec9a
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 23, 2024
42e653a
Merge branch 'main' into amyachev/bench-time
anmyachev Sep 23, 2024
8f81c13
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 23, 2024
5d08d3a
Merge branch 'main' into amyachev/bench-time
anmyachev Sep 29, 2024
b2d3398
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 29, 2024
fe806b1
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
bf49b0d
fix after merge
anmyachev Sep 30, 2024
7493632
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
524f81d
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
4d40864
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
b0d91ce
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
6809b9a
Merge remote-tracking branch 'origin' into amyachev/bench-time
anmyachev Oct 14, 2024
e1c4f9f
Change do_bench* signatures
anmyachev Oct 14, 2024
a1fd0f9
cleanup
anmyachev Oct 14, 2024
f16b149
fixes
anmyachev Oct 14, 2024
565d87c
fix
anmyachev Oct 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,11 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
sm_scale = 0.125
quantiles = [0.5, 0.0, 1.0]
warmup, rep = 150, 150
if provider == 'onednn':
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
False, scale=sm_scale), warmup=10, rep=10,
False, scale=sm_scale), warmup=warmup, rep=rep,
quantiles=quantiles, fast_flush=False)

elif provider == 'triton':
Expand All @@ -231,7 +232,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=warmup, rep=rep, quantiles=quantiles,
fast_flush=False)

elif provider == 'xetla':
Expand All @@ -246,7 +247,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=warmup, rep=rep, quantiles=quantiles,
fast_flush=False)

else:
Expand Down