misc: Various Updates to Attention Microbenchmark Suite (#1891)
## 📌 Description
This PR brings a set of updates to the attention microbenchmark suite in `flashinfer_benchmark.py`:
* `testBatchPrefillWithPagedKVCacheWrapper`:
  * Added `trtllm-gen-native` as a backend, which calls `flashinfer.prefill.trtllm_batch_context_with_kv_cache`. It is disabled for batch size 1 due to various errors; an issue will be filed to track them.
  * The `trtllm-gen` and `trtllm-gen-native` backends can now be benchmarked with FP8.
  * The `trtllm-gen` and `trtllm-gen-native` backends are now disabled for `causal=False`; the previous behavior silently ignored the flag and ran with `causal=True`.
* `testBatchPrefillWithRaggedKVCacheWrapper`:
  * Added `trtllm-gen-native` as a backend, which calls `flashinfer.prefill.trtllm_ragged_attention_deepseek`. It is disabled for batch size 1 due to various errors; an issue will be filed to track them.
* `testBatchMLAPagedAttentionWrapper`:
  * Added `cutlass` as a backend that can be benchmarked.
* Miscellaneous minor fixes, such as correcting refcheck failure messages.
Examples:
```shell
# python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends trtllm-gen-native fa2 cutlass --page_size 32 --batch_size 16 --s_qo 1 --s_kv 8192 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16
[PERF] trtllm-gen-nati:: median time 0.031 ms; std 0.000 ms; achieved tflops 553.684 TFLOPs/sec; achieved tb_per_sec 4.960 TB/sec
[PERF] fa2 :: median time 0.091 ms; std 0.001 ms; achieved tflops 190.364 TFLOPs/sec; achieved tb_per_sec 1.705 TB/sec
[PERF] cutlass :: median time 0.221 ms; std 0.000 ms; achieved tflops 78.342 TFLOPs/sec; achieved tb_per_sec 0.702 TB/sec
# python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --causal --refcheck --q_dtype bfloat16 --kv_dtype bfloat16
[PERF] fa2 :: median time 17.342 ms; std 0.011 ms; achieved tflops 397.579 TFLOPs/sec; achieved tb_per_sec 0.161 TB/sec
[PERF] cudnn :: median time 6.230 ms; std 0.032 ms; achieved tflops 1106.685 TFLOPs/sec; achieved tb_per_sec 0.449 TB/sec
[PERF] trtllm-gen :: median time 7.181 ms; std 0.040 ms; achieved tflops 960.135 TFLOPs/sec; achieved tb_per_sec 0.390 TB/sec
[PERF] trtllm-gen-nati:: median time 6.453 ms; std 0.012 ms; achieved tflops 1068.434 TFLOPs/sec; achieved tb_per_sec 0.434 TB/sec
# python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 cutlass cudnn trtllm-gen-native --batch_size 16 --s_qo 8192 --s_kv 8192 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16
[PERF] fa2 :: median time 39.797 ms; std 0.023 ms; achieved tflops 433.137 TFLOPs/sec; achieved tb_per_sec 0.312 TB/sec
[PERF] cutlass :: median time 18.509 ms; std 0.348 ms; achieved tflops 931.281 TFLOPs/sec; achieved tb_per_sec 0.672 TB/sec
[PERF] cudnn :: median time 14.778 ms; std 0.336 ms; achieved tflops 1166.391 TFLOPs/sec; achieved tb_per_sec 0.841 TB/sec
[PERF] trtllm-gen-nati:: median time 14.339 ms; std 0.291 ms; achieved tflops 1202.155 TFLOPs/sec; achieved tb_per_sec 0.867 TB/sec
```
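For post-processing, the `[PERF]` lines above can be parsed with a short stdlib-only script. This is a minimal sketch: the regex assumes the exact field layout shown in the example output (backend name, then `::`, then the four metrics), and the helper name `parse_perf_line` is illustrative, not part of the benchmark suite.

```python
import re

# Matches [PERF] lines emitted by flashinfer_benchmark.py, e.g.:
# [PERF] fa2 :: median time 17.342 ms; std 0.011 ms; achieved tflops 397.579 TFLOPs/sec; achieved tb_per_sec 0.161 TB/sec
# Note: backend names may be padded or truncated (e.g. "trtllm-gen-nati::"),
# so the pattern tolerates missing whitespace before "::".
PERF_RE = re.compile(
    r"\[PERF\]\s+(?P<backend>\S+)\s*::\s*"
    r"median time (?P<median_ms>[\d.]+) ms; "
    r"std (?P<std_ms>[\d.]+) ms; "
    r"achieved tflops (?P<tflops>[\d.]+) TFLOPs/sec; "
    r"achieved tb_per_sec (?P<tb_per_sec>[\d.]+) TB/sec"
)

def parse_perf_line(line: str) -> dict:
    """Parse one [PERF] line into a dict: backend name plus float metrics."""
    m = PERF_RE.search(line)
    if m is None:
        raise ValueError(f"not a [PERF] line: {line!r}")
    d = m.groupdict()
    return {"backend": d["backend"],
            **{k: float(v) for k, v in d.items() if k != "backend"}}

line = ("[PERF] cudnn          :: median time 6.230 ms; std 0.032 ms; "
        "achieved tflops 1106.685 TFLOPs/sec; achieved tb_per_sec 0.449 TB/sec")
print(parse_perf_line(line))
```

This makes it easy to collect the benchmark output into a table or CSV for comparing backends across configurations.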
**No changes to library code**
## 🔍 Related Issues
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes