Skip to content

Commit 0c9a92c

Browse files
authored
Refactor scripts in benchmarks to use flasinfer.testing.bench_gpu_time (#1337)
<!-- .github/pull_request_template.md --> ## 📌 Description PR includes mostly mechanical changes to benchmarking script & utilities: - Scripts in `benchmarks/bench_*.py` now use `flashinfer.testing.utils.bench_gpu_time` or `bench_gpu_time_with_cudagraph` instead of `triton.testing.do_bench` or `do_bench_cudagraph`. - Reported times are now **median** instead of **mean** (that was automatically computed by `triton.testing.do_bench`). Median times are more stable and better represent realistic times due to arithmetic means being susceptible to outliers and spikes. - All benchmark scripts have been tested to reproduce perf numbers from current scripts within noise. - Changes to `flashinfer.testing.utils.bench_gpu_time` or `bench_gpu_time_with_cudagraph`: - Removed NVTX range support. Experiments found that having NVTX ranges marked results in added overhead on the order of 10 usec, which makes measurements of fast kernels inaccurate. - The unnecessary `torch.cuda.synchronize()` after L2 flush was removed. The synchronizes also causes added overhead on the order of 10 usec, which makes measurements of fast kernels inaccurate. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent bc1a041 commit 0c9a92c

26 files changed

+340
-265
lines changed

benchmarks/bench_append_paged_kv_cache.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import dataclasses
33
from typing import Tuple, cast
44

5+
import numpy as np
56
import torch
6-
from triton.testing import do_bench
77

88
import flashinfer
9+
from flashinfer.testing.utils import bench_gpu_time
910

1011

1112
@dataclasses.dataclass(kw_only=True)
@@ -108,7 +109,8 @@ def fn_convert() -> Tuple[torch.Tensor, torch.Tensor]:
108109
)
109110

110111
batch_indices, positions = fn_convert()
111-
convert_latency_ms = cast(float, do_bench(fn_convert))
112+
convert_latencies = bench_gpu_time(fn_convert)
113+
convert_latency_ms = np.median(convert_latencies)
112114

113115
@torch.cuda.nvtx.range(f"append model={model_name}, seqlens={seqlens}")
114116
def fn() -> None:
@@ -124,7 +126,8 @@ def fn() -> None:
124126
"NHD",
125127
)
126128

127-
latency_ms = cast(float, do_bench(fn))
129+
latencies = bench_gpu_time(fn)
130+
latency_ms = np.median(latencies)
128131
all_layers_latency_ms = convert_latency_ms + latency_ms * model.num_layers
129132
throughput = (
130133
k.numel()

benchmarks/bench_append_paged_mla_kv_cache.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import dataclasses
33
from typing import Tuple, cast
44

5+
import numpy as np
56
import torch
6-
from triton.testing import do_bench
77

88
import flashinfer
9+
from flashinfer.testing.utils import bench_gpu_time
910

1011

1112
@dataclasses.dataclass(kw_only=True)
@@ -92,7 +93,8 @@ def fn_convert() -> Tuple[torch.Tensor, torch.Tensor]:
9293
)
9394

9495
batch_indices, positions = fn_convert()
95-
convert_latency_ms = cast(float, do_bench(fn_convert))
96+
convert_latencies = bench_gpu_time(fn_convert)
97+
convert_latency_ms = np.median(convert_latencies)
9698

9799
@torch.cuda.nvtx.range(f"append model={model_name}, seqlens={seqlens}")
98100
def fn() -> None:
@@ -108,7 +110,8 @@ def fn() -> None:
108110
kv_last_page_len,
109111
)
110112

111-
latency_ms = cast(float, do_bench(fn))
113+
latencies = bench_gpu_time(fn)
114+
latency_ms = np.median(latencies)
112115
all_layers_latency_ms = convert_latency_ms + latency_ms * model.num_layers
113116
throughput = (
114117
(ckv.numel() + kpe.numel())

benchmarks/bench_batch_attention.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import numpy as np
77
import pandas as pd
88
import torch
9-
from triton.testing import do_bench
109

1110
import flashinfer
11+
from flashinfer.testing.utils import bench_gpu_time
1212

1313

1414
def run_bench(
@@ -65,7 +65,8 @@ def run_bench(
6565
q_data_type=torch.bfloat16,
6666
kv_data_type=torch.bfloat16,
6767
)
68-
ms_old = do_bench(lambda: wrapper_old.run(q, kv_data))
68+
measurements_old = bench_gpu_time(lambda: wrapper_old.run(q, kv_data))
69+
ms_old = np.mean(measurements_old)
6970

7071
# new
7172
wrapper = flashinfer.BatchAttention(kv_layout="NHD")
@@ -83,7 +84,8 @@ def run_bench(
8384
q_data_type=torch.bfloat16,
8485
kv_data_type=torch.bfloat16,
8586
)
86-
ms_new = do_bench(lambda: wrapper.run(q, kv_data))
87+
measurements_new = bench_gpu_time(lambda: wrapper.run(q, kv_data))
88+
ms_new = np.mean(measurements_new)
8789

8890
total_bytes = (
8991
q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()

benchmarks/bench_batch_decode.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
import numpy as np
1818
import torch
19-
from triton.testing import do_bench
2019

2120
import flashinfer
21+
from flashinfer.testing.utils import bench_gpu_time
2222

2323
page_block_size = 16
2424
num_kv_heads = 4
@@ -67,7 +67,8 @@ def bench_batch_decode(
6767
q_data_type=q_dtype,
6868
)
6969

70-
ms = do_bench(lambda: wrapper.run(q, kv_data))
70+
measurements = bench_gpu_time(lambda: wrapper.run(q, kv_data))
71+
ms = np.median(measurements)
7172

7273
io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
7374
print(

benchmarks/bench_blackwell_attention.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
limitations under the License.
1515
"""
1616

17+
import numpy as np
1718
import torch
18-
from triton.testing import do_bench
1919

2020
import flashinfer
21+
from flashinfer.testing.utils import bench_gpu_time
2122

2223

2324
def bench_fmha_blackwell(
@@ -61,11 +62,12 @@ def bench_fmha_blackwell(
6162
kv_data_type=dtype,
6263
)
6364
o = wrapper.run(q, k, v)
64-
ms = do_bench(
65+
measurements = bench_gpu_time(
6566
lambda: wrapper.run(q, k, v),
66-
warmup=100,
67-
rep=1000,
67+
dry_run_time_ms=100,
68+
repeat_time_ms=1000,
6869
)
70+
ms = np.median(measurements)
6971

7072
def flops(ms):
7173
if causal:

benchmarks/bench_block_sparse_attention.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
limitations under the License.
1515
"""
1616

17+
import numpy as np
1718
import torch
18-
import triton
1919

2020
import flashinfer
21+
from flashinfer.testing.utils import bench_gpu_time
2122

2223

2324
def bench_variable_block_sparse_attention(
@@ -86,27 +87,34 @@ def bench_variable_block_sparse_attention(
8687
q_data_type=torch.half,
8788
)
8889

89-
sparse_ms_fa2 = triton.testing.do_bench(
90+
# Benchmark sparse attention with FA2
91+
measurements_fa2 = bench_gpu_time(
9092
lambda: sparse_wrapper_fa2.run(q, k, v),
91-
warmup=100,
92-
rep=1000,
93+
dry_run_time_ms=100,
94+
repeat_time_ms=1000,
9395
)
94-
sparse_ms_fa3 = triton.testing.do_bench(
96+
sparse_ms_fa2 = np.median(measurements_fa2)
97+
98+
# Benchmark sparse attention with FA3
99+
measurements_fa3 = bench_gpu_time(
95100
lambda: sparse_wrapper_fa3.run(q, k, v),
96-
warmup=100,
97-
rep=1000,
101+
dry_run_time_ms=100,
102+
repeat_time_ms=1000,
98103
)
104+
sparse_ms_fa3 = np.median(measurements_fa3)
99105

100106
q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
101107
k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
102108
v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
103109
dense_sm80_ms, dense_sm90_ms = (
104-
triton.testing.do_bench(
105-
lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
106-
q, k, v, causal=False, backend=backend
107-
),
108-
warmup=100,
109-
rep=1000,
110+
np.median(
111+
bench_gpu_time(
112+
lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
113+
q, k, v, causal=False, backend=backend
114+
),
115+
dry_run_time_ms=100,
116+
repeat_time_ms=1000,
117+
)
110118
)
111119
for backend in ["fa2", "fa3"]
112120
)

benchmarks/bench_cutlass_fused_moe.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
import argparse
1818
import pprint
1919

20+
import numpy as np
2021
import torch
2122
from torch.nn import functional as F
2223

2324
import flashinfer.fused_moe as fused_moe
2425
from flashinfer import fp4_quantize
2526
from flashinfer.autotuner import AutoTuner, autotune, get_config_path
26-
from flashinfer.testing.utils import bench_gpu_time_with_cudagraph
27+
from flashinfer.testing.utils import bench_gpu_time
2728

2829
FLOAT4_E2M1_MAX = 6.0
2930
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -173,7 +174,7 @@ def bench_cutlass_fused_moe(
173174
output=flash_output,
174175
tune_max_num_tokens=16384,
175176
)
176-
ms_list = bench_gpu_time_with_cudagraph(
177+
ms_list = bench_gpu_time(
177178
lambda: fused_moe.cutlass_fused_moe(
178179
hidden_states,
179180
selected_experts.to(torch.int),
@@ -184,12 +185,12 @@ def bench_cutlass_fused_moe(
184185
quant_scales=quant_scales,
185186
input_sf=input_sf,
186187
output=flash_output,
187-
)
188+
),
188189
)
189-
avg_ms = sum(ms_list) / len(ms_list)
190+
median_ms = np.median(ms_list)
190191
print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}")
191192
print(
192-
f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {avg_ms:.3f}"
193+
f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {median_ms:.3f}"
193194
)
194195

195196

benchmarks/bench_deepgemm_blackwell.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
limitations under the License.
1515
"""
1616

17+
import numpy as np
1718
import torch
18-
from triton.testing import do_bench
1919

2020
from flashinfer.gemm import (
2121
batch_deepgemm_fp8_nt_groupwise,
2222
group_deepgemm_fp8_nt_groupwise,
2323
)
24-
from flashinfer.testing.utils import quantize_fp8
24+
from flashinfer.testing.utils import bench_gpu_time, quantize_fp8
2525

2626

2727
def bench_deepgemm_grouped_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtype):
@@ -48,14 +48,14 @@ def bench_deepgemm_grouped_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtyp
4848
out = torch.empty(batch_size * m, n, device="cuda", dtype=out_dtype)
4949

5050
# Benchmark the DeepGEMM function
51-
ms = do_bench(
51+
measurements = bench_gpu_time(
5252
lambda: group_deepgemm_fp8_nt_groupwise(
5353
a_fp8, b_fp8, a_scale, b_scale, m_indices, out=out, out_dtype=out_dtype
5454
),
55-
warmup=100,
56-
rep=1000,
55+
dry_run_time_ms=100,
56+
repeat_time_ms=1000,
5757
)
58-
58+
ms = np.median(measurements)
5959
tflops_per_second = 2 * batch_size * m * n * k * 1e-9 / ms
6060
memory_bandwidth_per_second = (
6161
sum(
@@ -91,7 +91,7 @@ def bench_deepgemm_batch_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtype)
9191
out = torch.empty((batch_size, m, n), device="cuda", dtype=out_dtype)
9292

9393
# Benchmark the DeepGEMM function
94-
ms = do_bench(
94+
measurements = bench_gpu_time(
9595
lambda: batch_deepgemm_fp8_nt_groupwise(
9696
a_fp8,
9797
b_fp8,
@@ -102,9 +102,10 @@ def bench_deepgemm_batch_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtype)
102102
out=out,
103103
out_dtype=out_dtype,
104104
),
105-
warmup=100,
106-
rep=1000,
105+
dry_run_time_ms=100,
106+
repeat_time_ms=1000,
107107
)
108+
ms = np.median(measurements)
108109

109110
tflops_per_second = 2 * batch_size * m * n * k * 1e-9 / ms
110111
memory_bandwidth_per_second = (

benchmarks/bench_deepseek_mla.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
limitations under the License.
1515
"""
1616

17+
import numpy as np
1718
import torch
18-
import triton
1919

2020
import flashinfer
21+
from flashinfer.testing.utils import bench_gpu_time
2122

2223

2324
def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend):
@@ -61,11 +62,12 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend):
6162
)
6263
o = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False)
6364

64-
ms = triton.testing.do_bench(
65+
measurements = bench_gpu_time(
6566
lambda: wrapper.run(q_nope, q_pe, ckv, kpe),
66-
warmup=100,
67-
rep=1000,
67+
dry_run_time_ms=100,
68+
repeat_time_ms=1000,
6869
)
70+
ms = np.median(measurements)
6971

7072
io = sum([_.numel() * _.element_size() for _ in [q_nope, q_pe, ckv, kpe, o]])
7173
flops = 2 * batch_size * num_heads * (2 * head_dim_ckv + head_dim_kpe) * seq_len

benchmarks/bench_fused_add_rmsnorm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import argparse
22
from typing import cast
33

4+
import numpy as np
45
import torch
5-
from triton.testing import do_bench
66

77
import flashinfer
8+
from flashinfer.testing.utils import bench_gpu_time
89

910

1011
@torch.inference_mode()
@@ -42,7 +43,8 @@ def fn() -> None:
4243
flashinfer.fused_add_rmsnorm(x, residual, weight, eps)
4344

4445
# Run benchmarking
45-
latency_ms = cast(float, do_bench(fn))
46+
measurements = bench_gpu_time(fn)
47+
latency_ms = np.median(measurements)
4648
throughput = (
4749
x.numel() * x.element_size() * 2
4850
+ residual.numel() * residual.element_size() * 2

0 commit comments

Comments
 (0)