bugfix: do cudnn related error check only when cudnn backend is enabled. (#1377)
<!-- .github/pull_request_template.md -->
## 📌 Description
The cuDNN availability check is only needed when the user selects cuDNN as the backend, so `mm_fp4` no longer runs it unconditionally.
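The idea can be sketched in isolation as follows. This is a minimal illustration, not the code from this PR: `CUDNN_AVAILABLE`, `check_cudnn_available`, `matmul_sketch`, and the backend name `"other"` are hypothetical stand-ins, and the real function performs an actual FP4 matmul rather than returning a string.

```python
# Assumption: a module-level flag stands in for "is the cuDNN frontend importable?".
CUDNN_AVAILABLE = False


def check_cudnn_available() -> None:
    """Hypothetical availability check; raises if cuDNN cannot be used."""
    if not CUDNN_AVAILABLE:
        raise RuntimeError("cuDNN backend requested but cuDNN is not available")


def matmul_sketch(backend: str = "other") -> str:
    """Hypothetical entry point that only validates cuDNN when it is requested."""
    if backend == "cudnn":
        # The check is scoped to the branch that actually needs cuDNN.
        check_cudnn_available()
    # Other backends never touch the cuDNN check, so they keep working
    # on machines without cuDNN installed.
    return backend
```

With the check scoped this way, a call such as `matmul_sketch(backend="other")` succeeds even when cuDNN is missing, while `matmul_sketch(backend="cudnn")` still fails early with a clear error.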
## 🔍 Related Issues
<!-- Link any related issues here -->
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
## Reviewer Notes
<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->
Changed file: `flashinfer/gemm.py` (15 additions, 12 deletions)
@@ -1349,8 +1349,6 @@ def mm_fp4(
     >>> out.shape
     torch.Size([48, 256])
     """
-    _check_cudnn_fp4_availability()
-
     # pre-check the input tensor, block scale tensor and alpha tensor
     if a.ndim != 2 or b.ndim != 2:
         raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}")
@@ -1397,21 +1395,26 @@ def mm_fp4(
         dtype=out_dtype,
     )
 
-    # the fp4 cudnn graph will be shared for both mm and bmm, so here we need to get the 3d shape and stride including the batch dimension for both input and block scale tensors.