
Commit b7be894

Added xfail for mx_fp4 matmul on SM120 (#1766)
## 📌 Description

* A library bug is preventing mx_fp4 matmul on SM120.
* While waiting for the patch, the affected test is now xfailed.
* A `LibraryError` class was added to handle these issues in general.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 4e7969e commit b7be894
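
Condensed into a minimal, self-contained sketch (an illustration only; `mm_fp4_mxfp4_sm120` and the test name below are hypothetical stand-ins, and the real changes are in the diffs that follow): the library raises a dedicated `LibraryError` when the known cuDNN bug would be hit, and the test turns that error into an expected failure via `pytest.xfail` instead of reporting a hard failure.

```python
import pytest


class LibraryError(Exception):
    """Custom exception for library-related errors (mirrors flashinfer/utils.py)."""


def mm_fp4_mxfp4_sm120():
    # Hypothetical stand-in for the mm_fp4 path affected by the cuDNN bug:
    # on SM120 with cuDNN backend < 9.14.0 the library raises a clear error
    # instead of failing somewhere deep inside the cuDNN call.
    raise LibraryError(
        "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 "
        "with cuDNN backend version < 9.14.0."
    )


def test_mx_fp4_matmul_sm120():
    try:
        mm_fp4_mxfp4_sm120()
    except LibraryError as err:
        # Known library bug: record an expected failure until the cuDNN fix lands.
        pytest.xfail(str(err))
```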

File tree

3 files changed (+53 -26 lines): flashinfer/gemm.py, flashinfer/utils.py, tests/test_mm_fp4.py


flashinfer/gemm.py

Lines changed: 15 additions & 1 deletion
@@ -38,7 +38,12 @@
     last_positive_power_of_2,
 )
 from .jit.cubin_loader import get_cubin
-from .utils import is_sm100a_supported, is_sm120a_supported, is_sm121a_supported
+from .utils import (
+    is_sm100a_supported,
+    is_sm120a_supported,
+    is_sm121a_supported,
+    LibraryError,
+)

 CUDNN_AVAILABLE = False
 try:
@@ -2112,6 +2117,15 @@ def mm_fp4(
         raise ValueError("TRTLLM FP4 GEMM is not supported on SM110.")
     if backend != "cudnn" and not use_nvfp4:
         raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
+    if (
+        backend == "cudnn"
+        and not use_nvfp4
+        and _match_sm_version(a.device, ["120"])
+        and cudnn.backend_version() < 91400
+    ):
+        raise LibraryError(
+            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
+        )

     # allocate the output tensor if not provided
     if out is None:

flashinfer/utils.py

Lines changed: 7 additions & 7 deletions
@@ -52,15 +52,15 @@ class TensorLayout(Enum):


 class GPUArchitectureError(Exception):
-    def __init__(self, msg: str):
-        self.msg = msg
-        super().__init__(self.msg)
+    """Custom exception for GPU architecture-related errors."""

-    def __str__(self):
-        return self.msg
+    pass

-    def __repr__(self):
-        return self.msg
+
+class LibraryError(Exception):
+    """Custom exception for library-related errors."""
+
+    pass


 def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor:
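
As a rough illustration of the split between the two exception classes (a sketch under assumptions; `check_mxfp4_cudnn_support` is a hypothetical helper, not FlashInfer API): `GPUArchitectureError` is meant for hardware that cannot run a kernel at all, while `LibraryError` is meant for cases like this commit, where the GPU is capable but a dependency version blocks the path.

```python
# Hypothetical helper; only the two exception classes come from flashinfer/utils.py.
from flashinfer.utils import GPUArchitectureError, LibraryError


def check_mxfp4_cudnn_support(sm_major: int, sm_minor: int, cudnn_version: int) -> None:
    """Raise the appropriate error when the cuDNN mxfp4 GEMM path is unavailable."""
    if sm_major < 10:
        # The hardware itself lacks FP4 support: a permanent architecture limitation.
        raise GPUArchitectureError("FP4 GEMM requires an SM100-class or newer GPU.")
    if (sm_major, sm_minor) == (12, 0) and cudnn_version < 91400:
        # The GPU is capable; a dependency bug blocks the path until cuDNN 9.14.0.
        raise LibraryError(
            "cudnn FP4 GEMM with mxfp4 quantization on SM120 requires cuDNN backend >= 9.14.0."
        )
```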

tests/test_mm_fp4.py

Lines changed: 31 additions & 18 deletions
@@ -8,7 +8,7 @@
     nvfp4_quantize,
     mxfp4_quantize,
 )
-from flashinfer.utils import get_compute_capability
+from flashinfer.utils import get_compute_capability, LibraryError


 # TODO: Consdier splitting this function up for the various backends
@@ -25,10 +25,10 @@ def test_mm_fp4(
 ):
     use_nvfp4 = fp4_type == "nvfp4"

+    compute_capability = get_compute_capability(torch.device(device="cuda"))
     if backend == "trtllm":
         if res_dtype == torch.float16:
             pytest.skip("Skipping test for trtllm fp4 with float16")
-        compute_capability = get_compute_capability(torch.device(device="cuda"))
         if compute_capability[0] in [11, 12]:
             pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.")
     if not use_128x4_sf_layout and backend != "trtllm":
@@ -71,23 +71,36 @@ def test_mm_fp4(

     res = torch.empty([m, n], device="cuda", dtype=res_dtype)

-    with autotune(auto_tuning):
-        mm_fp4(
-            input_fp4,
-            mat2_fp4.T,
-            input_inv_s,
-            mat2_inv_s.T,
-            alpha,
-            res_dtype,
-            res,
-            block_size=block_size,
-            use_8x4_sf_layout=not use_128x4_sf_layout,
-            backend=backend,
-            use_nvfp4=use_nvfp4,
-        )
+    try:
+        with autotune(auto_tuning):
+            mm_fp4(
+                input_fp4,
+                mat2_fp4.T,
+                input_inv_s,
+                mat2_inv_s.T,
+                alpha,
+                res_dtype,
+                res,
+                block_size=block_size,
+                use_8x4_sf_layout=not use_128x4_sf_layout,
+                backend=backend,
+                use_nvfp4=use_nvfp4,
+            )

-    cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
-    assert cos_sim > 0.97
+        cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
+        assert cos_sim > 0.97
+    except LibraryError:
+        # TODO: Remove this check once cuDNN backend version is updated to 9.14.0
+        if (
+            backend == "cudnn"
+            and not use_nvfp4
+            and (compute_capability[0] == 12 and compute_capability[1] == 0)
+        ):
+            pytest.xfail(
+                "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
+            )
+        else:
+            pytest.fail("Unexpected LibraryError")


 if __name__ == "__main__":
