Skip to content

Commit 4e7969e

Browse files
authored
tests: skip non SM100/103 for grouped deepgemm (#1767)
<!-- .github/pull_request_template.md --> ## 📌 Description skip test_fp8_groupwise_group_deepgemm and test_fp8_groupwise_batch_deepgemm_masked where SM is not 100 or 103; add the corresponding checks in the library ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: jimmzhou <[email protected]>
1 parent 4fe837f commit 4e7969e

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

flashinfer/gemm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3078,6 +3078,11 @@ def group_deepgemm_fp8_nt_groupwise(
30783078
"""
30793079
from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
30803080

3081+
if not _match_sm_version(a.device, ["100", "103"]):
3082+
raise ValueError(
3083+
"m_grouped_fp8_gemm_nt_contiguous is only supported on SM100, SM103."
3084+
)
3085+
30813086
if out is None:
30823087
out_dtype = out_dtype or torch.bfloat16
30833088
out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device)
@@ -3206,6 +3211,11 @@ def batch_deepgemm_fp8_nt_groupwise(
32063211
"""
32073212
from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_masked
32083213

3214+
if not _match_sm_version(a.device, ["100", "103"]):
3215+
raise ValueError(
3216+
"m_grouped_fp8_gemm_nt_masked is only supported on SM100, SM103."
3217+
)
3218+
32093219
if out is None:
32103220
out_dtype = out_dtype or torch.bfloat16
32113221
out = torch.empty(

tests/test_groupwise_scaled_gemm_fp8.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ def test_fp8_groupwise_group_deepgemm(
202202
group_size,
203203
out_dtype,
204204
):
205+
compute_capability = get_compute_capability(torch.device(device="cuda"))
206+
if compute_capability[0] != 10:
207+
pytest.skip(
208+
"group_deepgemm_fp8_nt_groupwise is only supported on SM100, SM103 in trtllm backend."
209+
)
205210
torch.random.manual_seed(0)
206211
m_per_group = m // group_size
207212
if m_per_group < 128:
@@ -245,6 +250,11 @@ def test_fp8_groupwise_batch_deepgemm_masked(
245250
group_size,
246251
out_dtype,
247252
):
253+
compute_capability = get_compute_capability(torch.device(device="cuda"))
254+
if compute_capability[0] != 10:
255+
pytest.skip(
256+
"batch_deepgemm_fp8_nt_groupwise is only supported on SM100, SM103."
257+
)
248258
torch.random.manual_seed(0)
249259
n, k = nk
250260
a = torch.randn((group_size, m, k), device="cuda", dtype=torch.float32)

0 commit comments

Comments
 (0)