Commit 19747b4
benchmark: Addition of flashinfer_benchmark.py for benchmarking routines (#1323)
## 📌 Description

Adds benchmarks/flashinfer_benchmark.py and utility functions for benchmarking performance of various FI APIs. Test harness supports three attention backends and two GEMM backends, leaving room for future expansion.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent ba3f324 commit 19747b4

12 files changed: +3,171 −0 lines changed

benchmarks/README.md

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
# FlashInfer Perf Benchmarking Framework -- flashinfer_benchmark.py

A comprehensive testing and benchmarking framework for FlashInfer kernels.

The aim of `flashinfer_benchmark.py` is to provide a single framework for benchmarking any FlashInfer kernel and to replace standalone benchmarking scripts. The current support surface covers batched prefill & decode attention and FP8 GEMM.

## Overview

This framework provides tools to:
- Benchmark different attention implementations (FlashAttention2/3, cuDNN, CUTLASS, TensorRT-LLM)
- Benchmark FP8 GEMM performance
- Compare performance across different configurations
- Run batched performance tests over multiple attention test cases
- Generate detailed performance reports

Currently supports testing:
- `BatchDecodeWithPagedKVCacheWrapper` - Decode attention with paged KV cache
- `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache
- `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache
- `gemm_fp8_nt_groupwise` - GEMM with FP8 data types using groupwise scaling
- `group_gemm_fp8_nt_groupwise` - Grouped GEMM with FP8 data types using groupwise scaling

The support surface will expand to other operations, such as MLA and additional non-attention operations, in the future.

## Quick Start

### Single Test Run
Example commands:
```bash
# Test prefill attention with paged KV cache
python3 flashinfer_benchmark.py \
    --routine BatchPrefillWithPagedKVCacheWrapper \
    --backends fa2 cudnn \
    --page_size 16 \
    --batch_size 16 \
    --s_qo 4096 \
    --s_kv 4096 \
    --num_qo_heads 64 \
    --num_kv_heads 8 \
    --head_dim_qk 128 \
    --head_dim_vo 128 \
    --random_actual_seq_len \
    --verbose \
    --refcheck \
    --causal \
    --no_cuda_graph

# Test prefill attention with ragged KV cache
python3 flashinfer_benchmark.py \
    --routine BatchPrefillWithRaggedKVCacheWrapper \
    --backends fa2 cudnn cutlass \
    --batch_size 16 \
    --s_qo 4096 \
    --s_kv 4096 \
    --num_qo_heads 128 \
    --num_kv_heads 128 \
    --head_dim_qk 192 \
    --head_dim_vo 128 \
    --verbose \
    --refcheck \
    --causal \
    --no_cuda_graph

# Test decode attention with paged KV cache
python3 flashinfer_benchmark.py \
    --routine BatchDecodeWithPagedKVCacheWrapper \
    --backends fa2 fa2_tc trtllm cudnn \
    --page_size 16 \
    --batch_size 16 \
    --s_qo 1 \
    --s_kv 8192 \
    --num_qo_heads 64 \
    --num_kv_heads 8 \
    --head_dim_qk 128 \
    --head_dim_vo 128 \
    --random_actual_seq_len \
    --verbose \
    --refcheck

# FP8 GEMM
python3 flashinfer_benchmark.py \
    --routine gemm_fp8_nt_groupwise \
    --m 8192 \
    --n 4096 \
    --k 16384 \
    --mma_sm 2 \
    --refcheck \
    -vv

# Group FP8 GEMM
python3 flashinfer_benchmark.py \
    --routine group_gemm_fp8_nt_groupwise \
    --m 8192 \
    --n 4096 \
    --k 16384 \
    --mma_sm 2 \
    --group_size 2 \
    --no_cuda_graph \
    --scale_major_mode K \
    --refcheck \
    -vv
```
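If you want to sweep a single parameter (for example batch size) before committing to a full testlist, a small driver script that shells out to the CLI also works. The snippet below is only a convenience sketch: it uses flags documented in this README, and the sweep values are arbitrary examples.

```python
# Convenience sketch: sweep batch size for one decode configuration by
# shelling out to flashinfer_benchmark.py. Flags are the ones documented in
# this README; the sweep values are arbitrary examples.
import subprocess

for batch_size in (1, 8, 32, 128):
    cmd = [
        "python3", "flashinfer_benchmark.py",
        "--routine", "BatchDecodeWithPagedKVCacheWrapper",
        "--backends", "fa2_tc", "trtllm", "cudnn",
        "--page_size", "16",
        "--batch_size", str(batch_size),
        "--s_qo", "1",
        "--s_kv", "8192",
        "--num_qo_heads", "64",
        "--num_kv_heads", "8",
        "--head_dim_qk", "128",
        "--head_dim_vo", "128",
        "--refcheck",
    ]
    print(f"=== batch_size={batch_size} ===")
    subprocess.run(cmd, check=True)
```

For larger sweeps, the built-in `--testlist` mode described next is the intended workflow.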
### Batch Testing
Run multiple tests from a file and save results:
```bash
python3 flashinfer_benchmark.py --testlist samples/sample_testlist.txt --output_path sample_testlist_output.csv
```

The output CSV will contain detailed metrics including:
- Median execution time
- Standard deviation
- TFLOPS/sec
- Memory throughput (TB/sec)
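If you want to post-process that CSV, a few lines of pandas are enough. The snippet below is a minimal sketch: the column names used for grouping (`routine`, `backend`, `median_time`, `tflops`) are assumptions made for illustration, so check the header of your generated file and rename accordingly.

```python
# Minimal sketch for summarizing a flashinfer_benchmark CSV with pandas.
# NOTE: the column names below ("routine", "backend", "median_time", "tflops")
# are assumptions for illustration; inspect your CSV header and adjust.
import pandas as pd

df = pd.read_csv("sample_testlist_output.csv")
print(df.columns.tolist())  # confirm the real column names first

summary = (
    df.groupby(["routine", "backend"])[["median_time", "tflops"]]
      .median()
      .sort_values("tflops", ascending=False)
)
print(summary)
```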
## Command Line Arguments
### General Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--routine` | Test routine to run: `BatchDecodeWithPagedKVCacheWrapper`, `BatchPrefillWithPagedKVCacheWrapper`, `BatchPrefillWithRaggedKVCacheWrapper`, `gemm_fp8_nt_groupwise`, or `group_gemm_fp8_nt_groupwise` |
| `--num_iters` | Number of iterations for performance measurement |
| `--dry_run_iters` | Number of warmup iterations |
| `--no_cuda_graph` | Disable CUDA graphs and execute kernels outside of a graph |
| `--allow_output_mismatch`| Continue testing even if outputs don't match between backends |
| `--refcheck` | Verify that outputs match between different backends |
| `--random_seed` | Random seed for reproducibility |
| `--output_path` | Path to save CSV results |
| `--verbose`, `-v` | Print additional information (repeat, e.g. `-vv`, for more detail) |
### Attention Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--backends` | List of backends to test: `fa2`, `fa2_tc`, `fa3`, `cudnn`, `cutlass`, `trtllm` |
| `--page_size` | Page size for paged attention. Required for paged attention tests. |
| `--batch_size` | Number of sequences to process in parallel |
| `--s_qo` | Query/output sequence length. Should be 1 for decode tests. |
| `--s_kv` | Key/value sequence length (context length) |
| `--num_qo_heads` | Number of query/output attention heads |
| `--num_kv_heads` | Number of key/value attention heads |
| `--head_dim_qk` | Head dimension for Q/K. Must be 128 or 192. |
| `--head_dim_vo` | Head dimension for V/O. Usually equals `head_dim_qk`. |
| `--q_dtype` | Data type for the query tensor. Default: bfloat16. Currently only bfloat16 is supported. |
| `--kv_dtype` | Data type for the key and value tensors. Default: bfloat16. Currently only bfloat16 is supported. |
| `--causal` | Use causal attention masking (prefill only) |
| `--random_actual_seq_len`| Use random per-sequence lengths up to the maximum length. If not set, the maximum length is used. |
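To sanity-check reported attention throughput against a configuration, a back-of-the-envelope model built from these flags is handy. The sketch below uses a common accounting convention (2 FLOPs per multiply-accumulate in the QK^T and PV products, pairs roughly halved under causal masking, bf16 tensors read and written once); it is only an estimate, not necessarily the script's own accounting.

```python
# Rough attention cost model for a benchmark configuration (a sketch, not
# necessarily the script's exact accounting). Assumes bf16 (2 bytes/element).
def attention_estimate(batch_size, s_qo, s_kv, num_qo_heads, num_kv_heads,
                       head_dim_qk, head_dim_vo, causal, time_ms):
    # FLOPs: 2*d_qk for Q@K^T plus 2*d_vo for P@V, per (q, k) pair and head.
    flops_per_pair = 2 * head_dim_qk + 2 * head_dim_vo
    pairs = batch_size * s_qo * s_kv
    if causal and s_qo == s_kv:
        pairs //= 2  # roughly half the pairs are masked out
    flops = pairs * num_qo_heads * flops_per_pair

    # Bytes moved (minimum): read Q, K, V once and write O once, bf16.
    bytes_moved = 2 * batch_size * (
        s_qo * num_qo_heads * (head_dim_qk + head_dim_vo)    # Q and O
        + s_kv * num_kv_heads * (head_dim_qk + head_dim_vo)  # K and V
    )
    secs = time_ms / 1e3
    return flops / secs / 1e12, bytes_moved / secs / 1e12  # TFLOPs/s, TB/s

# Example: the fa2_tc decode run from the Example Outputs section (0.173 ms);
# this lands close to the TFLOPs/sec and TB/sec values the tool reports.
print(attention_estimate(32, 1, 8192, 64, 8, 128, 128, False, 0.173))
```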
### GEMM Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--m` | Number of rows of matrix A and output matrix (GEMM M dimension) |
| `--n` | Number of columns of matrix B and output matrix (GEMM N dimension) |
| `--k` | Number of columns of matrix A / rows of matrix B (GEMM K dimension) |
| `--tile_size` | Tile size for the GEMM operation (affects performance and scaling) |
| `--group_size` | Number of groups for group GEMM (batching multiple GEMMs together) |
| `--scale_major_mode` | Layout for FP8 scaling: `MN` (per output tile) or `K` (per input tile) |
| `--out_dtype` | Output data type: `bfloat16` or `float16` |
| `--mma_sm` | Number of SMs to use for the MMA operation (1 or 2) |
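The same kind of sanity check works for the GEMM routines using the standard 2·m·n·k FLOP count (multiplied by `group_size` for the grouped routine). This is a sketch; the timing value in the example is hypothetical.

```python
# Rough GEMM throughput estimate from the benchmark flags (a sketch; the
# script's own accounting may differ).
def gemm_tflops(m, n, k, time_ms, group_size=1):
    flops = 2 * m * n * k * group_size  # one multiply-accumulate = 2 FLOPs
    return flops / (time_ms / 1e3) / 1e12

# Example: --m 8192 --n 4096 --k 16384 at a hypothetical 1.0 ms median time.
print(f"{gemm_tflops(8192, 4096, 16384, 1.0):.1f} TFLOPs/sec")
```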
## Tester Attention Backend Support Matrix
The following support surface applies to attention operations in `flashinfer_benchmark.py`:

| Backend | Decode Paged | Prefill Paged | Prefill Ragged | FP8 | Notes |
|----------|--------------|---------------|----------------|-----|--------------------------------------------|
| fa2      |              |               |                |     | Does not support GQA ratio of 5            |
| fa2_tc   |              |               |                |     | Uses tensor cores                          |
| fa3      |              |               |                |     | Hopper only                                |
| cudnn    |              | *             | *              |     | *Requires specific head dims (192 or 128)  |
| cutlass  |              |               |                |     |                                            |
| trtllm   |              |               |                |     |                                            |

Notes:
- Currently only bfloat16 attention is supported.
- CUDA graph support is only stable with `BatchDecodeWithPagedKVCacheWrapper`. For `BatchPrefillWithPagedKVCacheWrapper` and `BatchPrefillWithRaggedKVCacheWrapper`, it is recommended to use `--no_cuda_graph`.
- The cudnn, cutlass, and trtllm backends are supported on [CUDA Compute Capability 10.0 GPUs](https://developer.nvidia.com/cuda-gpus) only.
- fa3 is supported on [CUDA Compute Capability 9.0 GPUs](https://developer.nvidia.com/cuda-gpus) only.
## Example Outputs
```bash
$ python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn --page_size 16 --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 8 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --causal --no_cuda_graph -vv
[INFO] args = Namespace(routine='BatchPrefillWithPagedKVCacheWrapper', backends=['fa2', 'cudnn'], page_size=16, batch_size=16, s_qo=1024, s_kv=1024, num_qo_heads=8, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, num_iters=30, dry_run_iters=5, no_cuda_graph=True, random_actual_seq_len=True, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None)
[INFO] FlashInfer version: 0.2.8
[INFO] Running testBatchPrefillWithPagedKVCacheWrapper
[VVERBOSE] gpu_name = 'NVIDIA_B200'
[VERBOSE] Average actual seq len: 327
[VVERBOSE] actual_seq_lens_q.flatten() = tensor([103, 436, 861, 271, 107, 72, 701, 21, 615, 122, 467, 215, 331, 459,
        88, 373], dtype=torch.int32)
[VVERBOSE] q.shape = torch.Size([5242, 8, 128])
[VVERBOSE] num_pages_per_seq = 64
[VVERBOSE] total_num_pages = 1024
[VVERBOSE] kv_cache.shape = torch.Size([1024, 2, 8, 16, 128])
[VVERBOSE] kv_cache.stride() = (32768, 16384, 128, 1024, 1)
[VVERBOSE] block_tables.shape = torch.Size([16, 64])
[VVERBOSE] qo_indptr.shape = torch.Size([17])
[VVERBOSE] qo_indptr.dtype = torch.int32
[VVERBOSE] kv_indptr.shape = torch.Size([17])
[VVERBOSE] kv_indices.shape = torch.Size([335])
[VVERBOSE] kv_last_page_len.shape = torch.Size([16])
[VVERBOSE] scale = 0.08838834764831843
[PERF] fa2 :: median time 0.094 ms; std 0.022 ms; achieved tflops 57.960 TFLOPs/sec; achieved tb_per_sec 0.459 TB/sec
[PERF] cudnn :: median time 0.096 ms; std 0.021 ms; achieved tflops 56.661 TFLOPs/sec; achieved tb_per_sec 0.449 TB/sec

$ python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 cudnn cutlass --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --refcheck --causal --no_cuda_graph -vv
[INFO] args = Namespace(routine='BatchPrefillWithRaggedKVCacheWrapper', backends=['fa2', 'cudnn', 'cutlass'], page_size=0, batch_size=16, s_qo=1024, s_kv=1024, num_qo_heads=128, num_kv_heads=128, head_dim_qk=192, head_dim_vo=128, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, num_iters=30, dry_run_iters=5, no_cuda_graph=True, random_actual_seq_len=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=True, vverbose=True, output_path=None)
[INFO] FlashInfer version: 0.2.8
[INFO] Running testBatchPrefillWithRaggedKVCacheWrapper
[VVERBOSE] gpu_name = 'NVIDIA_B200'
[VERBOSE] Average actual seq len: 1024
[VVERBOSE] actual_seq_lens_q.flatten() = tensor([1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024,
        1024, 1024, 1024, 1024], dtype=torch.int32)
[VVERBOSE] q.shape = torch.Size([16384, 128, 192])
[VVERBOSE] k.shape = torch.Size([16384, 128, 192])
[VVERBOSE] v.shape = torch.Size([16384, 128, 128])
[VVERBOSE] qo_indptr.shape = torch.Size([17])
[VVERBOSE] kv_indptr.shape = torch.Size([17])
[VVERBOSE] scale = 0.07216878364870323
[PERF] fa2 :: median time 2.197 ms; std 0.011 ms; achieved tflops 312.787 TFLOPs/sec; achieved tb_per_sec 1.222 TB/sec
[PERF] cudnn :: median time 1.008 ms; std 0.014 ms; achieved tflops 681.979 TFLOPs/sec; achieved tb_per_sec 2.664 TB/sec
[PERF] cutlass :: median time 1.453 ms; std 0.021 ms; achieved tflops 473.035 TFLOPs/sec; achieved tb_per_sec 1.848 TB/sec

$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc trtllm cudnn --page_size 16 --batch_size 32 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --refcheck -vv
[INFO] args = Namespace(routine='BatchDecodeWithPagedKVCacheWrapper', backends=['fa2', 'fa2_tc', 'trtllm', 'cudnn'], page_size=16, batch_size=32, s_qo=1, s_kv=8192, num_qo_heads=64, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, q_dtype='bfloat16', kv_dtype='bfloat16', causal=False, num_iters=30, dry_run_iters=5, no_cuda_graph=False, random_actual_seq_len=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None)
[INFO] FlashInfer version: 0.2.8
[INFO] Running testBatchDecodeWithPagedKVCacheWrapper
[VVERBOSE] gpu_name = 'NVIDIA_B200'
[VERBOSE] Average actual seq len: 8192
[VVERBOSE] actual_seq_lens_kv.flatten() = tensor([8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192,
        8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192,
        8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], device='cuda:0',
       dtype=torch.int32)
[VVERBOSE] q.shape = torch.Size([32, 64, 128])
[VVERBOSE] num_pages_per_seq = 512
[VVERBOSE] total_num_pages = 16384
[VVERBOSE] kv_cache.shape = torch.Size([16384, 2, 8, 16, 128])
[VVERBOSE] kv_cache.stride() = (32768, 16384, 128, 1024, 1)
[VVERBOSE] block_tables.shape = torch.Size([32, 512])
[VVERBOSE] kv_indptr.shape = torch.Size([33])
[VVERBOSE] kv_indices.shape = torch.Size([16384])
[VVERBOSE] kv_last_page_len.shape = torch.Size([32])
[VVERBOSE] scale = 0.08838834764831843
[PERF] fa2 :: median time 0.712 ms; std 0.000 ms; achieved tflops 12.061 TFLOPs/sec; achieved tb_per_sec 1.509 TB/sec
[PERF] fa2_tc :: median time 0.173 ms; std 0.001 ms; achieved tflops 49.779 TFLOPs/sec; achieved tb_per_sec 6.228 TB/sec
[PERF] trtllm :: median time 0.155 ms; std 0.000 ms; achieved tflops 55.344 TFLOPs/sec; achieved tb_per_sec 6.925 TB/sec
[PERF] cudnn :: median time 0.253 ms; std 0.000 ms; achieved tflops 33.964 TFLOPs/sec; achieved tb_per_sec 4.250 TB/sec
```
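If you capture console output instead of writing a CSV, the `[PERF]` lines are straightforward to scrape. The snippet below is a small sketch written against the log format shown above; adjust the regex if the format changes.

```python
# Scrape "[PERF] <backend> :: median time ...; std ...; achieved tflops ...;
# achieved tb_per_sec ..." lines from captured benchmark output. Written
# against the log format shown above; adjust if the format changes.
import re

PERF_RE = re.compile(
    r"\[PERF\]\s+(?P<backend>\S+)\s*::\s*median time (?P<ms>[\d.]+) ms; "
    r"std (?P<std>[\d.]+) ms; achieved tflops (?P<tflops>[\d.]+) TFLOPs/sec; "
    r"achieved tb_per_sec (?P<tbps>[\d.]+) TB/sec"
)

def parse_perf_lines(text):
    """Return one dict per [PERF] line with numeric fields converted to float."""
    return [
        {k: (v if k == "backend" else float(v)) for k, v in m.groupdict().items()}
        for m in PERF_RE.finditer(text)
    ]

example = "[PERF] fa2 :: median time 0.712 ms; std 0.000 ms; achieved tflops 12.061 TFLOPs/sec; achieved tb_per_sec 1.509 TB/sec"
print(parse_perf_lines(example))
```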
