# FlashInfer Perf Benchmarking Framework -- flashinfer_benchmark.py

A comprehensive testing and benchmarking framework for FlashInfer's kernels.

The aim of `flashinfer_benchmark.py` is to provide a single framework for benchmarking any FlashInfer kernel and to replace standalone benchmarking scripts. The current support surface includes batched prefill and decode attention, and FP8 GEMM.

## Overview

This framework provides tools to:
- Benchmark different attention implementations (FlashAttention2/3, cuDNN, CUTLASS, TensorRT-LLM)
- Benchmark GEMM performance
- Compare performance across different configurations
- Batch-run multiple test cases from a single testlist file
- Generate detailed performance reports

Currently supported routines:
- `BatchDecodeWithPagedKVCacheWrapper` - Decode attention with paged KV cache
- `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache
- `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache
- `gemm_fp8_nt_groupwise` - GEMM with FP8 data types using groupwise scaling
- `group_gemm_fp8_nt_groupwise` - Grouped GEMM with FP8 data types using groupwise scaling

The support surface will expand to other operations, such as MLA and additional non-attention kernels, in the future.

## Quick Start

### Single Test Run
Example commands:
```bash
# Test prefill attention with paged KV cache
python3 flashinfer_benchmark.py \
    --routine BatchPrefillWithPagedKVCacheWrapper \
    --backends fa2 cudnn \
    --page_size 16 \
    --batch_size 16 \
    --s_qo 4096 \
    --s_kv 4096 \
    --num_qo_heads 64 \
    --num_kv_heads 8 \
    --head_dim_qk 128 \
    --head_dim_vo 128 \
    --random_actual_seq_len \
    --verbose \
    --refcheck \
    --causal \
    --no_cuda_graph

# Test prefill attention with ragged KV cache
python3 flashinfer_benchmark.py \
    --routine BatchPrefillWithRaggedKVCacheWrapper \
    --backends fa2 cudnn cutlass \
    --batch_size 16 \
    --s_qo 4096 \
    --s_kv 4096 \
    --num_qo_heads 128 \
    --num_kv_heads 128 \
    --head_dim_qk 192 \
    --head_dim_vo 128 \
    --verbose \
    --refcheck \
    --causal \
    --no_cuda_graph

# Test decode attention with paged KV cache
python3 flashinfer_benchmark.py \
    --routine BatchDecodeWithPagedKVCacheWrapper \
    --backends fa2 fa2_tc trtllm cudnn \
    --page_size 16 \
    --batch_size 16 \
    --s_qo 1 \
    --s_kv 8192 \
    --num_qo_heads 64 \
    --num_kv_heads 8 \
    --head_dim_qk 128 \
    --head_dim_vo 128 \
    --random_actual_seq_len \
    --verbose \
    --refcheck

# FP8 GEMM
python3 flashinfer_benchmark.py \
    --routine gemm_fp8_nt_groupwise \
    --m 8192 \
    --n 4096 \
    --k 16384 \
    --mma_sm 2 \
    --refcheck \
    -vv

# Group FP8 GEMM
python3 flashinfer_benchmark.py \
    --routine group_gemm_fp8_nt_groupwise \
    --m 8192 \
    --n 4096 \
    --k 16384 \
    --mma_sm 2 \
    --group_size 2 \
    --no_cuda_graph \
    --scale_major_mode K \
    --refcheck \
    -vv
```

### Batch Testing

Run multiple tests from a file and save results:
```bash
python3 flashinfer_benchmark.py --testlist samples/sample_testlist.txt --output_path sample_testlist_output.csv
```

The output CSV will contain detailed metrics including:
- Median execution time
- Standard deviation
- TFLOPS/sec
- Memory throughput (TB/sec)
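
One way to inspect the results afterwards is with pandas. This is a minimal sketch, not part of the framework; the column names used below (for example a per-backend TFLOPs column) are assumptions, so check the header of your generated CSV for the actual names.

```python
# Sketch: inspect a benchmark CSV produced with --output_path.
# Column names such as "tflops" are assumptions; print df.columns first
# to discover the real ones in your version of the framework.
import pandas as pd

df = pd.read_csv("sample_testlist_output.csv")
print(df.columns.tolist())  # discover the actual column names

# Example: sort runs by achieved TFLOPs/sec, if such a column exists
if "tflops" in df.columns:
    print(df.sort_values("tflops", ascending=False).head())
```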

## Command Line Arguments
### General Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--routine` | Test routine to run: `BatchDecodeWithPagedKVCacheWrapper`, `BatchPrefillWithPagedKVCacheWrapper`, `BatchPrefillWithRaggedKVCacheWrapper`, `gemm_fp8_nt_groupwise`, or `group_gemm_fp8_nt_groupwise` |
| `--testlist` | Path to a file listing multiple test configurations to run in batch |
| `--num_iters` | Number of iterations for performance measurement |
| `--dry_run_iters` | Number of warmup iterations |
| `--no_cuda_graph` | Disable CUDA graph and run kernels outside of a graph |
| `--allow_output_mismatch`| Continue testing even if outputs don't match between backends |
| `--refcheck` | Verify that outputs match between different backends |
| `--random_seed` | Random seed for reproducibility |
| `--output_path` | Path to save CSV results |
| `--verbose`, `-v` | Print additional information; pass `-vv` for even more detail |

### Attention Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--backends` | List of backends to test: `fa2`, `fa2_tc`, `fa3`, `cudnn`, `cutlass`, `trtllm` |
| `--page_size` | Page size for paged attention. Required for paged attention tests. |
| `--batch_size` | Number of sequences to process in parallel |
| `--s_qo` | Query/output sequence length. Should be 1 for decode tests. |
| `--s_kv` | Key/value sequence length (context length) |
| `--num_qo_heads` | Number of query/output attention heads |
| `--num_kv_heads` | Number of key/value attention heads |
| `--head_dim_qk` | Head dimension for Q/K. Must be 128 or 192. |
| `--head_dim_vo` | Head dimension for V/O. Usually equals `head_dim_qk`. |
| `--q_dtype` | Data type for the query tensor. Default: `bfloat16`. Currently only `bfloat16` is supported. |
| `--kv_dtype` | Data type for the key and value tensors. Default: `bfloat16`. Currently only `bfloat16` is supported. |
| `--causal` | Use causal attention masking (prefill only) |
| `--random_actual_seq_len`| Use random per-sequence lengths up to the maximum length. If not set, every sequence uses the maximum length. |

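The achieved-TFLOPs numbers reported for attention runs line up with the standard attention FLOP estimate (2 * q_tokens * kv_tokens * (head_dim_qk + head_dim_vo) per head, halved for causal masking). The sketch below is not code from the framework, just that estimate written out for sanity-checking the `[PERF]` lines; applied to the cudnn ragged-prefill case in the Example Outputs section it roughly reproduces the reported 682 TFLOPs/sec.

```python
# Sketch: approximate attention FLOP accounting for a prefill test case.
# Counts the QK^T and PV matmuls (2 FLOPs per multiply-add) and halves the
# total for causal masking. Not the framework's own code; treat as an estimate.
def attention_tflops(batch_size, s_qo, s_kv, num_qo_heads,
                     head_dim_qk, head_dim_vo, causal, time_ms):
    flops_per_seq = 2 * s_qo * s_kv * num_qo_heads * (head_dim_qk + head_dim_vo)
    if causal:
        flops_per_seq /= 2  # roughly half of the score matrix is masked out
    total_flops = batch_size * flops_per_seq
    return total_flops / (time_ms * 1e-3) / 1e12

# Ragged prefill example below: cudnn measured at ~1.008 ms -> ~682 TFLOPs/sec
print(attention_tflops(16, 1024, 1024, 128, 192, 128, True, 1.008))
```
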
### GEMM Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--m` | Number of rows of matrix A and output matrix (GEMM M dimension) |
| `--n` | Number of columns of matrix B and output matrix (GEMM N dimension) |
| `--k` | Number of columns of matrix A / rows of matrix B (GEMM K dimension) |
| `--tile_size` | Tile size for the GEMM operation (affects performance and scaling) |
| `--group_size` | Number of groups for group GEMM (batching multiple GEMMs together) |
| `--scale_major_mode` | Layout for FP8 scaling: `MN` (per output tile) or `K` (per input tile) |
| `--out_dtype` | Output data type: `bfloat16` or `float16` |
| `--mma_sm` | Number of SMs to use for the MMA operation (1 or 2) |
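
For GEMM routines, achieved TFLOPs can be sanity-checked with the usual 2 * M * N * K FLOP count. The helper below is a minimal sketch under that assumption, not the framework's own accounting.

```python
# Sketch: standard GEMM FLOP accounting (2 * m * n * k multiply-adds),
# scaled by group_size for grouped GEMM. time_ms is whatever the [PERF]
# line reports as the median time.
def gemm_tflops(m, n, k, time_ms, group_size=1):
    total_flops = 2 * m * n * k * group_size
    return total_flops / (time_ms * 1e-3) / 1e12

# e.g. the 8192 x 4096 x 16384 case from Quick Start at a hypothetical 0.5 ms
print(gemm_tflops(8192, 4096, 16384, 0.5))  # ~2199 TFLOPs/sec
```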

## Tester Attention Backend Support Matrix
The following support surface applies to attention operations in `flashinfer_benchmark.py`:
| Backend | Decode Paged | Prefill Paged | Prefill Ragged | FP8 | Notes |
|----------|-------------|---------------|----------------|------|------------------------------------------|
| fa2 | ✓ | ✓ | ✓ | ✗ | Does not support GQA ratio of 5 |
| fa2_tc | ✓ | ✗ | ✗ | ✗ | Uses tensor cores |
| fa3 | ✗ | ✓ | ✓ | ✗ | Hopper only |
| cudnn | ✓ | ✓* | ✓* | ✗ | *Requires specific head dims (192 or 128) |
| cutlass | ✗ | ✗ | ✓ | ✗ | |
| trtllm | ✓ | ✗ | ✗ | ✗ | |

Notes:
- Only bfloat16 attention is currently supported.
- CUDA graph support is only stable with `BatchDecodeWithPagedKVCacheWrapper`. For `BatchPrefillWithPagedKVCacheWrapper` and `BatchPrefillWithRaggedKVCacheWrapper`, it is recommended to use `--no_cuda_graph`.
- The cudnn, cutlass, and trtllm backends are supported on [CUDA Compute Capability 10.0 GPUs](https://developer.nvidia.com/cuda-gpus) only.
- fa3 is supported on [CUDA Compute Capability 9.0 GPUs](https://developer.nvidia.com/cuda-gpus) only.
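
Because backend availability depends on the GPU architecture, it can help to filter the `--backends` list by compute capability before launching a sweep. The helper below is illustrative only; it simply mirrors the notes above and is not part of `flashinfer_benchmark.py`.

```python
# Sketch: choose attention backends to benchmark from the device's compute
# capability, following the support notes above. Illustrative only.
import torch

def candidate_backends():
    major, minor = torch.cuda.get_device_capability()
    backends = ["fa2", "fa2_tc"]
    if (major, minor) == (9, 0):    # Compute Capability 9.0 (Hopper)
        backends.append("fa3")
    if (major, minor) == (10, 0):   # Compute Capability 10.0
        backends += ["cudnn", "cutlass", "trtllm"]
    return backends

print(candidate_backends())
```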

## Example Outputs
```bash
$ python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn --page_size 16 --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 8 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --causal --no_cuda_graph -vv
[INFO] args = Namespace(routine='BatchPrefillWithPagedKVCacheWrapper', backends=['fa2', 'cudnn'], page_size=16, batch_size=16, s_qo=1024, s_kv=1024, num_qo_heads=8, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, num_iters=30, dry_run_iters=5, no_cuda_graph=True, random_actual_seq_len=True, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None)
[INFO] FlashInfer version: 0.2.8
[INFO] Running testBatchPrefillWithPagedKVCacheWrapper
[VVERBOSE] gpu_name = 'NVIDIA_B200'
[VERBOSE] Average actual seq len: 327
[VVERBOSE] actual_seq_lens_q.flatten() = tensor([103, 436, 861, 271, 107, 72, 701, 21, 615, 122, 467, 215, 331, 459,
        88, 373], dtype=torch.int32)
[VVERBOSE] q.shape = torch.Size([5242, 8, 128])
[VVERBOSE] num_pages_per_seq = 64
[VVERBOSE] total_num_pages = 1024
[VVERBOSE] kv_cache.shape = torch.Size([1024, 2, 8, 16, 128])
[VVERBOSE] kv_cache.stride() = (32768, 16384, 128, 1024, 1)
[VVERBOSE] block_tables.shape = torch.Size([16, 64])
[VVERBOSE] qo_indptr.shape = torch.Size([17])
[VVERBOSE] qo_indptr.dtype = torch.int32
[VVERBOSE] kv_indptr.shape = torch.Size([17])
[VVERBOSE] kv_indices.shape = torch.Size([335])
[VVERBOSE] kv_last_page_len.shape = torch.Size([16])
[VVERBOSE] scale = 0.08838834764831843
[PERF] fa2 :: median time 0.094 ms; std 0.022 ms; achieved tflops 57.960 TFLOPs/sec; achieved tb_per_sec 0.459 TB/sec
[PERF] cudnn :: median time 0.096 ms; std 0.021 ms; achieved tflops 56.661 TFLOPs/sec; achieved tb_per_sec 0.449 TB/sec

$ python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 cudnn cutlass --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --refcheck --causal --no_cuda_graph -vv
[INFO] args = Namespace(routine='BatchPrefillWithRaggedKVCacheWrapper', backends=['fa2', 'cudnn', 'cutlass'], page_size=0, batch_size=16, s_qo=1024, s_kv=1024, num_qo_heads=128, num_kv_heads=128, head_dim_qk=192, head_dim_vo=128, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, num_iters=30, dry_run_iters=5, no_cuda_graph=True, random_actual_seq_len=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=True, vverbose=True, output_path=None)
[INFO] FlashInfer version: 0.2.8
[INFO] Running testBatchPrefillWithRaggedKVCacheWrapper
[VVERBOSE] gpu_name = 'NVIDIA_B200'
[VERBOSE] Average actual seq len: 1024
[VVERBOSE] actual_seq_lens_q.flatten() = tensor([1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024,
        1024, 1024, 1024, 1024], dtype=torch.int32)
[VVERBOSE] q.shape = torch.Size([16384, 128, 192])
[VVERBOSE] k.shape = torch.Size([16384, 128, 192])
[VVERBOSE] v.shape = torch.Size([16384, 128, 128])
[VVERBOSE] qo_indptr.shape = torch.Size([17])
[VVERBOSE] kv_indptr.shape = torch.Size([17])
[VVERBOSE] scale = 0.07216878364870323
[PERF] fa2 :: median time 2.197 ms; std 0.011 ms; achieved tflops 312.787 TFLOPs/sec; achieved tb_per_sec 1.222 TB/sec
[PERF] cudnn :: median time 1.008 ms; std 0.014 ms; achieved tflops 681.979 TFLOPs/sec; achieved tb_per_sec 2.664 TB/sec
[PERF] cutlass :: median time 1.453 ms; std 0.021 ms; achieved tflops 473.035 TFLOPs/sec; achieved tb_per_sec 1.848 TB/sec

$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc trtllm cudnn --page_size 16 --batch_size 32 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --refcheck -vv
[INFO] args = Namespace(routine='BatchDecodeWithPagedKVCacheWrapper', backends=['fa2', 'fa2_tc', 'trtllm', 'cudnn'], page_size=16, batch_size=32, s_qo=1, s_kv=8192, num_qo_heads=64, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, q_dtype='bfloat16', kv_dtype='bfloat16', causal=False, num_iters=30, dry_run_iters=5, no_cuda_graph=False, random_actual_seq_len=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None)
[INFO] FlashInfer version: 0.2.8
[INFO] Running testBatchDecodeWithPagedKVCacheWrapper
[VVERBOSE] gpu_name = 'NVIDIA_B200'
[VERBOSE] Average actual seq len: 8192
[VVERBOSE] actual_seq_lens_kv.flatten() = tensor([8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192,
        8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192,
        8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192], device='cuda:0',
       dtype=torch.int32)
[VVERBOSE] q.shape = torch.Size([32, 64, 128])
[VVERBOSE] num_pages_per_seq = 512
[VVERBOSE] total_num_pages = 16384
[VVERBOSE] kv_cache.shape = torch.Size([16384, 2, 8, 16, 128])
[VVERBOSE] kv_cache.stride() = (32768, 16384, 128, 1024, 1)
[VVERBOSE] block_tables.shape = torch.Size([32, 512])
[VVERBOSE] kv_indptr.shape = torch.Size([33])
[VVERBOSE] kv_indices.shape = torch.Size([16384])
[VVERBOSE] kv_last_page_len.shape = torch.Size([32])
[VVERBOSE] scale = 0.08838834764831843
[PERF] fa2 :: median time 0.712 ms; std 0.000 ms; achieved tflops 12.061 TFLOPs/sec; achieved tb_per_sec 1.509 TB/sec
[PERF] fa2_tc :: median time 0.173 ms; std 0.001 ms; achieved tflops 49.779 TFLOPs/sec; achieved tb_per_sec 6.228 TB/sec
[PERF] trtllm :: median time 0.155 ms; std 0.000 ms; achieved tflops 55.344 TFLOPs/sec; achieved tb_per_sec 6.925 TB/sec
[PERF] cudnn :: median time 0.253 ms; std 0.000 ms; achieved tflops 33.964 TFLOPs/sec; achieved tb_per_sec 4.250 TB/sec
```