Commit fdc8338

misc: Various Updates to Attention Microbenchmark Suite (#1891)
## 📌 Description

This PR brings a host of updates to the attention microbenchmark suites in `flashinfer_benchmark.py`:

* `testBatchPrefillWithPagedKVCacheWrapper`:
  * `trtllm-gen-native`, which calls `flashinfer.prefill.trtllm_batch_context_with_kv_cache`, is added as a backend. It is disabled for batch size 1 due to various errors; an issue will be filed to track them.
  * The `trtllm-gen` and `trtllm-gen-native` backends can now be benchmarked with FP8.
  * `trtllm-gen` and `trtllm-gen-native` are now disabled for `causal=False`; previously the flag was silently ignored and the run used `causal=True`.
* `testBatchPrefillWithRaggedKVCacheWrapper`:
  * `trtllm-gen-native`, which calls `flashinfer.prefill.trtllm_ragged_attention_deepseek`, is added as a backend. It is disabled for batch size 1 due to various errors; an issue will be filed to track them.
* `testBatchMLAPagedAttentionWrapper`:
  * The `cutlass` backend can now be benchmarked (a usage sketch follows this description).
* Miscellaneous minor fixes, such as correcting the refcheck failure messages.

Examples:

```
# python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends trtllm-gen-native fa2 cutlass --page_size 32 --batch_size 16 --s_qo 1 --s_kv 8192 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16
[PERF] trtllm-gen-nati:: median time 0.031 ms; std 0.000 ms; achieved tflops 553.684 TFLOPs/sec; achieved tb_per_sec 4.960 TB/sec
[PERF] fa2            :: median time 0.091 ms; std 0.001 ms; achieved tflops 190.364 TFLOPs/sec; achieved tb_per_sec 1.705 TB/sec
[PERF] cutlass        :: median time 0.221 ms; std 0.000 ms; achieved tflops 78.342 TFLOPs/sec; achieved tb_per_sec 0.702 TB/sec

# python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --causal --refcheck --q_dtype bfloat16 --kv_dtype bfloat16
[PERF] fa2            :: median time 17.342 ms; std 0.011 ms; achieved tflops 397.579 TFLOPs/sec; achieved tb_per_sec 0.161 TB/sec
[PERF] cudnn          :: median time 6.230 ms; std 0.032 ms; achieved tflops 1106.685 TFLOPs/sec; achieved tb_per_sec 0.449 TB/sec
[PERF] trtllm-gen     :: median time 7.181 ms; std 0.040 ms; achieved tflops 960.135 TFLOPs/sec; achieved tb_per_sec 0.390 TB/sec
[PERF] trtllm-gen-nati:: median time 6.453 ms; std 0.012 ms; achieved tflops 1068.434 TFLOPs/sec; achieved tb_per_sec 0.434 TB/sec

# python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 cutlass cudnn trtllm-gen-native --batch_size 16 --s_qo 8192 --s_kv 8192 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16
[PERF] fa2            :: median time 39.797 ms; std 0.023 ms; achieved tflops 433.137 TFLOPs/sec; achieved tb_per_sec 0.312 TB/sec
[PERF] cutlass        :: median time 18.509 ms; std 0.348 ms; achieved tflops 931.281 TFLOPs/sec; achieved tb_per_sec 0.672 TB/sec
[PERF] cudnn          :: median time 14.778 ms; std 0.336 ms; achieved tflops 1166.391 TFLOPs/sec; achieved tb_per_sec 0.841 TB/sec
[PERF] trtllm-gen-nati:: median time 14.339 ms; std 0.291 ms; achieved tflops 1202.155 TFLOPs/sec; achieved tb_per_sec 0.867 TB/sec
```

**No changes to library code.**
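For reviewers who want to see the new MLA backend in isolation: below is a minimal, hedged sketch of how the `cutlass` path is wired up, mirroring the keyword arguments used in `benchmarks/routines/attention.py`. The helper name is illustrative, tensor preparation is left to the caller, and `use_cuda_graph=False` is an assumption for the sketch; it is not part of the PR itself.

```python
# Illustrative sketch only (not benchmark code): the cutlass MLA path constructs the
# wrapper with backend="cutlass", skips plan(), and passes the KV lengths and page
# table directly to run(). Tensors (q_nope, q_pe, caches, block tables) are assumed
# to be prepared by the caller exactly as the benchmark does.
import flashinfer


def run_cutlass_mla(workspace_buffer, actual_seq_lens_kv, block_tables,
                    q_nope, q_pe, ckv_cache, kpe_cache):
    wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
        float_workspace_buffer=workspace_buffer,
        use_cuda_graph=False,  # assumption: eager mode for this sketch
        backend="cutlass",
    )
    # Unlike the fa2/fa3 paths, no plan() call is made for cutlass; the per-request
    # KV lengths and the page table go straight into run().
    return wrapper.run(
        q_nope,
        q_pe,
        ckv_cache,
        kpe_cache,
        kv_len=actual_seq_lens_kv.flatten(),
        page_table=block_tables,
        return_lse=False,
    )
```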
## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent d910f9a commit fdc8338

File tree

3 files changed: +125 −36 lines


benchmarks/README.md

Lines changed: 3 additions & 3 deletions
@@ -217,9 +217,9 @@ Legend:
 | Routine | 7.5 | 8.0 | 8.6 | 8.9 | 9.0 | 10.0 | 10.3 | 12.0 |
 |---------|-----|-----|-----|-----|-----|-------|-------|-------|
 | **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native | fa2, fa2_tc, cudnn |
-| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen | fa2, cudnn, trtllm-gen | fa2, cudnn |
-| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass | fa2, cudnn, cutlass | fa2, cudnn |
-| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, trtllm-gen-native | fa2, trtllm-gen-native | fa2 |
+| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen, trtllm-gen-native | fa2, cudnn, trtllm-gen, trtllm-gen-native | fa2, cudnn |
+| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass, trtllm-gen-native | fa2, cudnn, cutlass, trtllm-gen-native | fa2, cudnn |
+| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-gen-native | fa2, cutlass, trtllm-gen-native | fa2 |
 | **gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | |
 | **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | |
 | **bmm_fp8** | | | | cudnn, cublas | cudnn, cublas | cudnn, cublas, cutlass | cudnn, cublas, cutlass | cudnn, cublas |

benchmarks/routines/attention.py

Lines changed: 116 additions & 27 deletions
@@ -545,7 +545,7 @@ def run_backend_wrapper(backend):
             ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
             if num_different_elements > 0:
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"[ERROR] Output tensor mismatch between backends fa2 and {tested_backends[i]}: "
                     f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
@@ -689,14 +689,22 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
 
     if "trtllm-gen" in backends:
         remove_trtllm = False
-        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
-            torch.float8_e4m3fn,
-            torch.float8_e5m2,
-        ]:
-            print("[INFO] trtllm-gen backend does not support FP8. Skipping.")
+        if not causal:
+            print("[INFO] trtllm-gen backend currently requires causal = True")
             remove_trtllm = True
         if remove_trtllm:
             backends.remove("trtllm-gen")
+    if "trtllm-gen-native" in backends:
+        remove_trtllm_native = False
+        if batch_size == 1:
+            # TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
+            print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
+            remove_trtllm_native = True
+        if not causal:
+            print("[INFO] trtllm-gen-native backend currently requires causal = True")
+            remove_trtllm_native = True
+        if remove_trtllm_native:
+            backends.remove("trtllm-gen-native")
 
     if "cutlass" in backends:
         print("[INFO] CUTLASS backend does not support prefill. Skipping.")
@@ -1006,7 +1014,7 @@ def run_backend_wrapper(backend):
             ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
             if num_different_elements > 0:
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"[ERROR] Output tensor mismatch between backends fa2 and {tested_backends[i]}: "
                     f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
@@ -1129,6 +1137,13 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
 
     backends = filter_backends_by_compute_capability(backends, args.routine, device)
     # Check for backend-specific constraints
+    if "fa2" in backends:
+        remove_fa2 = False
+        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            print("[INFO] FA2 backend does not support FP8. Skipping.")
+            remove_fa2 = True
+        if remove_fa2:
+            backends.remove("fa2")
     if "cudnn" in backends:
         remove_cudnn = False
         if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
@@ -1161,6 +1176,25 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
             remove_trtllm = True
         if remove_trtllm:
             backends.remove("trtllm-gen")
+    if "trtllm-gen-native" in backends:
+        remove_trtllm_native = False
+        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        ]:
+            print("[INFO] trtllm-gen-native backend does not support FP8. Skipping.")
+            remove_trtllm_native = True
+        if batch_size == 1:
+            # TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
+            print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
+            remove_trtllm_native = True
+        if not (head_dim_qk == 192 and head_dim_vo == 128):
+            print(
+                "[INFO] trtllm-gen-native backend requires head_dim_qk == 192 and head_dim_vo == 128"
+            )
+            remove_trtllm_native = True
+        if remove_trtllm_native:
+            backends.remove("trtllm-gen-native")
 
     if len(backends) == 0:
         print("[ERROR] No backends to test. Exiting.")
@@ -1372,6 +1406,26 @@ def run_backend_wrapper(backend):
                 batch_offsets_stats=batch_offsets_stats,
                 is_cuda_graph_compatible=True,
             )[0]
+        elif backend == "trtllm-gen-native":
+            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=workspace_buffer,
+                seq_lens=actual_seq_lens_kv_device,
+                max_q_len=s_qo,
+                max_kv_len=s_kv,
+                bmm1_scale=scale,
+                bmm2_scale=1.0,
+                o_sf_scale=-1,
+                batch_size=batch_size,
+                window_left=-1,
+                cum_seq_lens_q=qo_indptr,
+                cum_seq_lens_kv=kv_indptr,
+                enable_pdl=False,
+                is_causal=causal,
+                return_lse=True,
+            )[0]
         else:
             print(f"[ERROR] Backend {backend} not supported")
             return res
@@ -1416,7 +1470,7 @@ def run_backend_wrapper(backend):
             ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
             if num_different_elements > 0:
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"[ERROR] Output tensor mismatch between backends fa2 and {tested_backends[i]}: "
                     f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
@@ -1484,7 +1538,7 @@ def run_backend_wrapper(backend):
 def testBatchMLAPagedAttentionWrapper(args):
     """
     Test BatchMLAPagedAttentionWrapper and equivalent APIs.
-    Supports fa2. and trtllm-gen-native.
+    Supports fa2, fa3, cutlass, and trtllm-gen-native.
 
     This test:
     1. Creates paged query and key-value cache tensors
@@ -1565,6 +1619,30 @@ def testBatchMLAPagedAttentionWrapper(args):
             remove_fa3 = True
         if remove_fa3:
             backends.remove("fa3")
+    if "cutlass" in backends:
+        remove_cutlass = False
+        if page_size not in [32, 64]:
+            print(
+                "[INFO] Cutlass MLA backend only supports page size 32 or 64. Skipping."
+            )
+            remove_cutlass = True
+        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        ]:
+            print("[INFO] Cutlass MLA backend does not support FP8. Skipping.")
+            remove_cutlass = True
+        if remove_cutlass:
+            backends.remove("cutlass")
+    if "trtllm-gen-native" in backends:
+        remove_trtllm_native = False
+        if page_size not in [32, 64]:
+            print(
+                "[INFO] trtllm-gen-native backend only supports page size 32 or 64. Skipping."
+            )
+            remove_trtllm_native = True
+        if remove_trtllm_native:
+            backends.remove("trtllm-gen-native")
     if len(backends) == 0:
         print("[ERROR] No backends to test. Exiting.")
         return res
@@ -1629,7 +1707,7 @@ def testBatchMLAPagedAttentionWrapper(args):
         page_size,
         head_dim_kpe,
     )
-    kpe_cache = torch.randn(size=kpe_cache_shape, dtype=q_init_dtype, device=device)
+    kpe_cache = torch.randn(size=kpe_cache_shape, dtype=kv_init_dtype, device=device)
     kv_cache = torch.cat([ckv_cache, kpe_cache], dim=2)
 
     qo_indptr = torch.arange(0, batch_size + 1, device=device).int()
@@ -1657,7 +1735,7 @@
         device=device,
     )
 
-    sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5)
+    sm_scale = 1.0 / ((128 + 64) ** 0.5)  # For DeepSeek-R1
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
 
     if args.verbose >= 2:
@@ -1674,7 +1752,7 @@
     # Create wrapper
     backend_wrappers = {}
    for backend in backends:
-        if backend in ["fa2", "fa3"]:
+        if backend in ["fa2", "fa3", "cutlass"]:
            backend_wrappers[backend] = flashinfer.mla.BatchMLAPagedAttentionWrapper(
                float_workspace_buffer=workspace_buffer,
                use_cuda_graph=is_cuda_graph_compatible,
@@ -1684,20 +1762,21 @@
                 kv_len_arr=actual_seq_lens_kv,
                 backend=backend,
             )
-            backend_wrappers[backend].plan(
-                qo_indptr=qo_indptr,
-                kv_indptr=kv_indptr,
-                kv_indices=kv_indices,
-                kv_len_arr=actual_seq_lens_kv,
-                num_heads=num_qo_heads,
-                head_dim_ckv=head_dim_ckv,
-                head_dim_kpe=head_dim_kpe,
-                page_size=page_size,
-                causal=causal,
-                sm_scale=sm_scale,
-                q_data_type=q_dtype,
-                kv_data_type=kv_dtype,
-            )
+            if backend != "cutlass":
+                backend_wrappers[backend].plan(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=kv_indptr,
+                    kv_indices=kv_indices,
+                    kv_len_arr=actual_seq_lens_kv,
+                    num_heads=num_qo_heads,
+                    head_dim_ckv=head_dim_ckv,
+                    head_dim_kpe=head_dim_kpe,
+                    page_size=page_size,
+                    causal=causal,
+                    sm_scale=sm_scale,
+                    q_data_type=q_dtype,
+                    kv_data_type=kv_dtype,
+                )
 
     if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
         q = q.to(q_dtype)
@@ -1713,6 +1792,16 @@ def run_backend_wrapper(backend):
             return backend_wrappers[backend].run(
                 q_nope, q_pe, ckv_cache, kpe_cache, return_lse=False
             )
+        elif backend == "cutlass":
+            return backend_wrappers[backend].run(
+                q_nope,
+                q_pe,
+                ckv_cache,
+                kpe_cache,
+                kv_len=actual_seq_lens_kv.flatten(),
+                page_table=block_tables,
+                return_lse=False,
+            )
         if backend == "trtllm-gen-native":
             return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
                 query=q.unsqueeze(1),
@@ -1767,7 +1856,7 @@
             ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
             if num_different_elements > 0:
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"[ERROR] Output tensor mismatch between backends fa2 and {tested_backends[i]}: "
                     f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
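To make the new ragged `trtllm-gen-native` path easier to review outside the benchmark harness, here is a hedged sketch of the call added in `run_backend_wrapper` above. The keyword arguments are copied from the diff; the helper name and the assumption that `q`/`k`/`v` are packed (ragged) tensors indexed by the cumulative indptrs are illustrative and not verified against the API documentation.

```python
# Hedged sketch, not benchmark code: direct call to the trtllm-gen-native ragged
# prefill path added above. Keyword arguments mirror the diff; the ragged/packed
# tensor layout is an assumption based on the indptr arguments the benchmark passes.
import flashinfer


def trtllm_native_ragged_prefill(q, k, v, qo_indptr, kv_indptr, seq_lens_kv,
                                 workspace_buffer, batch_size, s_qo, s_kv,
                                 scale, causal):
    result = flashinfer.prefill.trtllm_ragged_attention_deepseek(
        query=q,
        key=k,
        value=v,
        workspace_buffer=workspace_buffer,
        seq_lens=seq_lens_kv,          # per-request KV lengths on device
        max_q_len=s_qo,
        max_kv_len=s_kv,
        bmm1_scale=scale,              # softmax scale folded into the first GEMM
        bmm2_scale=1.0,
        o_sf_scale=-1,
        batch_size=batch_size,
        window_left=-1,                # no sliding window
        cum_seq_lens_q=qo_indptr,
        cum_seq_lens_kv=kv_indptr,
        enable_pdl=False,
        is_causal=causal,
        return_lse=True,
    )
    return result[0]  # the benchmark keeps only the first element (attention output)
```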

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 6 additions & 6 deletions
@@ -177,8 +177,8 @@ def dtype_str_to_torch_dtype(dtype_str):
         "8.6": ["fa2", "cudnn"],
         "8.9": ["fa2", "cudnn"],
         "9.0": ["fa2", "fa3", "cudnn"],
-        "10.0": ["fa2", "cudnn", "trtllm-gen"],
-        "10.3": ["fa2", "cudnn", "trtllm-gen"],
+        "10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-gen-native"],
+        "10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-gen-native"],
         "12.0": ["fa2", "cudnn"],
     },
     "BatchPrefillWithRaggedKVCacheWrapper": {
@@ -187,8 +187,8 @@ def dtype_str_to_torch_dtype(dtype_str):
         "8.6": ["fa2", "cudnn"],
         "8.9": ["fa2", "cudnn"],
         "9.0": ["fa2", "fa3", "cudnn"],
-        "10.0": ["fa2", "cudnn", "cutlass"],
-        "10.3": ["fa2", "cudnn", "cutlass"],
+        "10.0": ["fa2", "cudnn", "cutlass", "trtllm-gen-native"],
+        "10.3": ["fa2", "cudnn", "cutlass", "trtllm-gen-native"],
         "12.0": ["fa2", "cudnn"],
     },
     "BatchMLAPagedAttentionWrapper": {
@@ -197,8 +197,8 @@ def dtype_str_to_torch_dtype(dtype_str):
         "8.6": ["fa2"],
         "8.9": ["fa2"],
         "9.0": ["fa2", "fa3"],
-        "10.0": ["fa2", "trtllm-gen-native"],
-        "10.3": ["fa2", "trtllm-gen-native"],
+        "10.0": ["fa2", "cutlass", "trtllm-gen-native"],
+        "10.3": ["fa2", "cutlass", "trtllm-gen-native"],
         "12.0": ["fa2"],
     },
     # GEMM
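The per-compute-capability lists above are consumed by `filter_backends_by_compute_capability`, which the attention routines call before applying backend-specific checks. Its implementation is not part of this diff, so the following is only a hedged sketch of how such a filter might work; the dictionary name `ROUTINE_SUPPORTED_BACKENDS` and the function body are hypothetical.

```python
# Hypothetical sketch of how per-compute-capability backend lists like those above
# could be consumed. This is NOT the actual filter_backends_by_compute_capability
# from the benchmark utils; names and structure here are assumptions.
import torch

ROUTINE_SUPPORTED_BACKENDS = {  # hypothetical name; shape mirrors the dict in the diff
    "BatchMLAPagedAttentionWrapper": {
        "9.0": ["fa2", "fa3"],
        "10.0": ["fa2", "cutlass", "trtllm-gen-native"],
    },
}


def filter_backends_sketch(backends, routine, device):
    # Map the device's SM version (e.g. (10, 0)) to the string keys used above.
    major, minor = torch.cuda.get_device_capability(device)
    cc = f"{major}.{minor}"
    supported = ROUTINE_SUPPORTED_BACKENDS.get(routine, {}).get(cc, [])
    for backend in backends:
        if backend not in supported:
            print(f"[INFO] {backend} is not supported on SM {cc} for {routine}. Skipping.")
    return [b for b in backends if b in supported]
```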
