
Commit 50319b2

feat: Add compute capability checks to flashinfer_benchmark (#1756)
## 📌 Description

This PR:

* Adds a compute capability check function (`filter_backends_by_compute_capability()`) to `flashinfer_benchmark.py`. Backends are now skipped during benchmarks if they are not supported on the current compute capability.
  - The previous behavior for unsupported cases was undefined.
  - A dictionary of supported routine, compute capability, and backend combinations is placed in `flashinfer_benchmark_utils.py`.
* Updates `benchmarks/README.md` to document a support matrix at the bottom.
* Updates the sample testlist and outputs with the latest `flashinfer_benchmark.py` usage and sample outputs.
* Updates `README.md` to make it clear that SM architecture 75 is supported.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent eb25f07 commit 50319b2
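
For a quick sense of the mechanism this PR introduces, the sketch below mirrors how each benchmark routine now filters its backend list before running. The toy support table and stub filter are illustrative only; the real table `routine_cc_to_supported_backends` and helper `filter_backends_by_compute_capability()` appear in the `flashinfer_benchmark_utils.py` diff further down.

```python
# Illustrative sketch only -- not the benchmark code itself. The real helper
# lives in benchmarks/routines/flashinfer_benchmark_utils.py and keys its table
# by routine name and "major.minor" compute capability.
import torch

TOY_SUPPORT_TABLE = {  # hypothetical, trimmed-down table for illustration
    "BatchDecodeWithPagedKVCacheWrapper": {
        "9.0": ["fa2", "fa2_tc", "cudnn"],
        "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"],
    },
}


def toy_filter_backends(backends, routine, device):
    """Drop backends that the current GPU's compute capability cannot run."""
    major, minor = torch.cuda.get_device_capability(device)
    cc = f"{major}.{minor}"
    supported = TOY_SUPPORT_TABLE.get(routine, {}).get(cc, [])
    kept = []
    for backend in backends:
        if backend in supported:
            kept.append(backend)
        else:
            print(f"[WARNING] {backend} is not supported on compute capability {cc}. Skipping.")
    return kept


if torch.cuda.is_available():
    backends = ["fa2", "cudnn", "trtllm-gen"]
    backends = toy_filter_backends(
        backends, "BatchDecodeWithPagedKVCacheWrapper", torch.device("cuda:0")
    )
    if len(backends) == 0:
        print("[ERROR] No backends to test. Exiting.")  # routines now return early here
    print(backends)
```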

File tree

9 files changed: +686 -336 lines changed


README.md
Lines changed: 1 addition & 1 deletion

@@ -125,7 +125,7 @@ FlashInfer also provides C++ API and TVM bindings, please refer to [documentatio
 
 ## GPU Support
 
-FlashInfer currently provides support for NVIDIA SM architectures 80 and higher and beta support for 103, 110, 120, and 121.
+FlashInfer currently provides support for NVIDIA SM architectures 75 and higher and beta support for 103, 110, 120, and 121.
 
 ## Adoption
 
benchmarks/README.md
Lines changed: 101 additions & 81 deletions
(Large diffs are not rendered by default.)

benchmarks/routines/attention.py
Lines changed: 49 additions & 39 deletions

@@ -15,6 +15,7 @@
     get_device,
     print_perf_metrics,
     is_close_stats,
+    filter_backends_by_compute_capability,
 )
 
 
@@ -241,7 +242,8 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
     # return_lse = not args.no_lse  # TO-DO: Add support for this
     run_refcheck = args.refcheck
 
-    # Derived parameters
+    backends = filter_backends_by_compute_capability(backends, args.routine, device)
+    # Check for backend-specific constraints
     if "fa2" in backends:
         remove_fa2 = False
         head_grp_size = (
@@ -279,7 +281,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
 
     if len(backends) == 0:
         print("[ERROR] No backends to test. Exiting.")
-        return
+        return res
 
     # Storage for timing results and outputs
     backend_times = {backend: [] for backend in backends}
@@ -665,6 +667,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
     # return_lse = not args.no_lse  # TO-DO: Add support for this
     run_refcheck = args.refcheck
 
+    backends = filter_backends_by_compute_capability(backends, args.routine, device)
     # Check for backend-specific constraints
     if "fa2" in backends:
         remove_fa2 = False
@@ -673,16 +676,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
             remove_fa2 = True
         if remove_fa2:
             backends.remove("fa2")
-    if "fa3" in backends:
-        remove_fa3 = False
-        device_capability = torch.cuda.get_device_capability()
-        if device_capability[0] != 9:
-            print(
-                f"[INFO] FA3 backend does not support capability {device_capability}. Skipping."
-            )
-            remove_fa3 = True
-        if remove_fa3:
-            backends.remove("fa3")
     if "cudnn" in backends:
         remove_cudnn = False
         if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
@@ -1134,6 +1127,7 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
     # return_lse = not args.no_lse  # TO-DO: Add support for this
     run_refcheck = args.refcheck
 
+    backends = filter_backends_by_compute_capability(backends, args.routine, device)
     # Check for backend-specific constraints
     if "cudnn" in backends:
         remove_cudnn = False
@@ -1170,7 +1164,7 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
 
     if len(backends) == 0:
         print("[ERROR] No backends to test. Exiting.")
-        return
+        return res
 
     # Check for layer-specific constraints
     layer_not_supported = False
@@ -1549,6 +1543,7 @@ def testBatchMLAPagedAttentionWrapper(args):
     causal = False  # False for MLA
     run_refcheck = args.refcheck
 
+    backends = filter_backends_by_compute_capability(backends, args.routine, device)
     # Check for backend-specific constraints
     if "fa2" in backends:
         remove_fa2 = False
@@ -1560,6 +1555,19 @@ def testBatchMLAPagedAttentionWrapper(args):
             remove_fa2 = True
         if remove_fa2:
             backends.remove("fa2")
+    if "fa3" in backends:
+        remove_fa3 = False
+        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        ]:
+            print("[INFO] FA3 backend does not support FP8. Skipping.")
+            remove_fa3 = True
+        if remove_fa3:
+            backends.remove("fa3")
+    if len(backends) == 0:
+        print("[ERROR] No backends to test. Exiting.")
+        return res
 
     # Storage for timing results and outputs
     backend_times = {backend: [] for backend in backends}
@@ -1664,30 +1672,32 @@ def testBatchMLAPagedAttentionWrapper(args):
         print(f"[VVERBOSE] {workspace_buffer.shape = }")
 
     # Create wrapper
-    if "fa2" in backends:
-        fi_fa2_mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-            float_workspace_buffer=workspace_buffer,
-            use_cuda_graph=is_cuda_graph_compatible,
-            qo_indptr=qo_indptr,
-            kv_indptr=kv_indptr,
-            kv_indices=kv_indices,
-            kv_len_arr=actual_seq_lens_kv,
-            backend="fa2",
-        )
-        fi_fa2_mla_wrapper.plan(
-            qo_indptr=qo_indptr,
-            kv_indptr=kv_indptr,
-            kv_indices=kv_indices,
-            kv_len_arr=actual_seq_lens_kv,
-            num_heads=num_qo_heads,
-            head_dim_ckv=head_dim_ckv,
-            head_dim_kpe=head_dim_kpe,
-            page_size=page_size,
-            causal=causal,
-            sm_scale=sm_scale,
-            q_data_type=q_dtype,
-            kv_data_type=kv_dtype,
-        )
+    backend_wrappers = {}
+    for backend in backends:
+        if backend in ["fa2", "fa3"]:
+            backend_wrappers[backend] = flashinfer.mla.BatchMLAPagedAttentionWrapper(
+                float_workspace_buffer=workspace_buffer,
+                use_cuda_graph=is_cuda_graph_compatible,
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=actual_seq_lens_kv,
+                backend=backend,
+            )
+            backend_wrappers[backend].plan(
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=actual_seq_lens_kv,
+                num_heads=num_qo_heads,
+                head_dim_ckv=head_dim_ckv,
+                head_dim_kpe=head_dim_kpe,
+                page_size=page_size,
+                causal=causal,
+                sm_scale=sm_scale,
+                q_data_type=q_dtype,
+                kv_data_type=kv_dtype,
+            )
 
     if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
         q = q.to(q_dtype)
@@ -1699,8 +1709,8 @@ def testBatchMLAPagedAttentionWrapper(args):
         kv_cache = kv_cache.to(kv_dtype)
 
     def run_backend_wrapper(backend):
-        if backend == "fa2":
-            return fi_fa2_mla_wrapper.run(
+        if backend in ["fa2", "fa3"]:
+            return backend_wrappers[backend].run(
                 q_nope, q_pe, ckv_cache, kpe_cache, return_lse=False
             )
         if backend == "trtllm-gen-native":

benchmarks/routines/flashinfer_benchmark_utils.py
Lines changed: 149 additions & 0 deletions

@@ -1,6 +1,7 @@
 import torch
 
 from flashinfer.testing.utils import set_seed
+from flashinfer.utils import get_compute_capability
 
 # Output columns for the test results.
 output_column_dict = {
@@ -156,3 +157,151 @@ def dtype_str_to_torch_dtype(dtype_str):
         return torch.float8_e5m2
     else:
         raise ValueError(f"Unsupported dtype: {dtype_str}")
+
+
+routine_cc_to_supported_backends = {
+    # ATTENTION
+    "BatchDecodeWithPagedKVCacheWrapper": {
+        "7.5": ["fa2"],
+        "8.0": ["fa2", "fa2_tc", "cudnn"],
+        "8.6": ["fa2", "fa2_tc", "cudnn"],
+        "8.9": ["fa2", "fa2_tc", "cudnn"],
+        "9.0": ["fa2", "fa2_tc", "cudnn"],
+        "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"],
+        "10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"],
+        "12.0": ["fa2", "fa2_tc", "cudnn"],
+    },
+    "BatchPrefillWithPagedKVCacheWrapper": {
+        "7.5": [],
+        "8.0": ["fa2", "cudnn"],
+        "8.6": ["fa2", "cudnn"],
+        "8.9": ["fa2", "cudnn"],
+        "9.0": ["fa2", "fa3", "cudnn"],
+        "10.0": ["fa2", "cudnn", "trtllm-gen"],
+        "10.3": ["fa2", "cudnn", "trtllm-gen"],
+        "12.0": ["fa2", "cudnn"],
+    },
+    "BatchPrefillWithRaggedKVCacheWrapper": {
+        "7.5": [],
+        "8.0": ["fa2", "cudnn"],
+        "8.6": ["fa2", "cudnn"],
+        "8.9": ["fa2", "cudnn"],
+        "9.0": ["fa2", "fa3", "cudnn"],
+        "10.0": ["fa2", "cudnn", "cutlass"],
+        "10.3": ["fa2", "cudnn", "cutlass"],
+        "12.0": ["fa2", "cudnn"],
+    },
+    "BatchMLAPagedAttentionWrapper": {
+        "7.5": [],
+        "8.0": ["fa2"],
+        "8.6": ["fa2"],
+        "8.9": ["fa2"],
+        "9.0": ["fa2", "fa3"],
+        "10.0": ["fa2", "trtllm-gen-native"],
+        "10.3": ["fa2", "trtllm-gen-native"],
+        "12.0": ["fa2"],
+    },
+    # GEMM
+    "gemm_fp8_nt_groupwise": {
+        "7.5": [],
+        "8.0": [],
+        "8.6": [],
+        "8.9": [],
+        "9.0": [],
+        "10.0": ["cutlass"],
+        "10.3": ["cutlass"],
+        "12.0": [],
+    },
+    "group_gemm_fp8_nt_groupwise": {
+        "7.5": [],
+        "8.0": [],
+        "8.6": [],
+        "8.9": [],
+        "9.0": [],
+        "10.0": ["cutlass"],
+        "10.3": ["cutlass"],
+        "12.0": [],
+    },
+    "bmm_fp8": {
+        "7.5": [],
+        "8.0": [],
+        "8.6": [],
+        "8.9": ["cudnn", "cublas"],
+        "9.0": ["cudnn", "cublas"],
+        "10.0": ["cudnn", "cublas", "cutlass"],
+        "10.3": ["cudnn", "cublas", "cutlass"],
+        "12.0": ["cudnn", "cublas"],
+    },
+    "mm_fp4": {
+        "7.5": [],
+        "8.0": [],
+        "8.6": [],
+        "8.9": [],
+        "9.0": [],
+        "10.0": ["cudnn", "trtllm", "cutlass"],
+        "10.3": ["cudnn", "trtllm", "cutlass"],
+        "12.0": ["cudnn"],
+    },
+    # MOE
+    "trtllm_fp4_block_scale_moe": {
+        "7.5": [],
+        "8.0": [],
+        "8.6": [],
+        "8.9": [],
+        "9.0": [],
+        "10.0": ["trtllm"],
+        "10.3": ["trtllm"],
+        "12.0": [],
+    },
+    "trtllm_fp8_block_scale_moe": {
+        "7.5": [],
+        "8.0": [],
+        "8.6": [],
+        "8.9": [],
+        "9.0": [],
+        "10.0": ["trtllm"],
+        "10.3": ["trtllm"],
+        "12.0": [],
+    },
+    "trtllm_fp8_per_tensor_scale_moe": {
+        "7.5": [],
+        "8.0": [],
+        "8.6": [],
+        "8.9": [],
+        "9.0": [],
+        "10.0": ["trtllm"],
+        "10.3": ["trtllm"],
+        "12.0": [],
+    },
+    "cutlass_fused_moe": {
+        "7.5": [],
+        "8.0": [],
+        "8.6": [],
+        "8.9": [],
+        "9.0": [],
+        "10.0": ["cutlass"],
+        "10.3": ["cutlass"],
+        "12.0": [],
+    },
+}
+
+
+def filter_backends_by_compute_capability(backends, routine, device):
+    # FlashInfer currently does not have an isSupported() function that checks support.
+    # WAR: Use helper function to check support.
+    major, minor = get_compute_capability(device)
+    compute_capability = f"{major}.{minor}"
+
+    # If the compute capability is not supported, return an empty list.
+    cc_to_supported_backends = routine_cc_to_supported_backends[routine]
+    supported_backends = cc_to_supported_backends.get(compute_capability, [])
+    backends_to_remove = []
+    for backend in backends:
+        if backend not in supported_backends:
+            backends_to_remove.append(backend)
+    for backend in backends_to_remove:
+        backends.remove(backend)
+        print(
+            f"[WARNING] {backend} for routine {routine} is not supported on compute capability {compute_capability}. Skipping."
+        )
+    return backends
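
As a usage sketch, calling the new helper directly might look like the snippet below. Assumptions: a CUDA GPU, flashinfer installed, and that this module is importable as `routines.flashinfer_benchmark_utils` when running from the `benchmarks/` directory; the import path is not shown in this diff.

```python
# Hypothetical direct use of filter_backends_by_compute_capability; the import
# path is an assumption. On an SM 9.0 (Hopper) GPU, the decode table above
# keeps fa2/fa2_tc/cudnn and drops trtllm-gen with a [WARNING] message.
import torch
from routines.flashinfer_benchmark_utils import filter_backends_by_compute_capability

device = torch.device("cuda:0")
backends = ["fa2", "fa2_tc", "cudnn", "trtllm-gen"]
backends = filter_backends_by_compute_capability(
    backends, "BatchDecodeWithPagedKVCacheWrapper", device
)
print(backends)  # expected ["fa2", "fa2_tc", "cudnn"] on a 9.0 device
```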
