Commit b79b78c

bugfix: do cudnn-related error check only when the cudnn backend is enabled (#1377)
## 📌 Description

We only need to check cuDNN availability when the user requests cuDNN as the backend.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
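For illustration, here is a minimal, self-contained sketch of the gating pattern this commit applies. `run_gemm`, `_check_cudnn_available`, and the `"trtllm"` backend name are illustrative stand-ins, not FlashInfer's actual API; the real check is `_check_cudnn_fp4_availability()` in `flashinfer/gemm.py`:

```python
def _check_cudnn_available() -> None:
    # Stand-in for _check_cudnn_fp4_availability(): fail loudly only when
    # the cudnn Python package is actually required but cannot be imported.
    try:
        import cudnn  # noqa: F401
    except ImportError as e:
        raise RuntimeError("cudnn backend requested but cudnn is unavailable") from e


def run_gemm(backend: str = "cudnn") -> str:
    # Before this fix, the availability check ran unconditionally, so users
    # of other backends could fail on a missing or outdated cuDNN install.
    # After the fix, only the cudnn path pays for (and can fail) the check.
    if backend == "cudnn":
        _check_cudnn_available()
        return "dispatched to the cudnn fp4 graph"
    return f"dispatched to the {backend} kernel"


print(run_gemm("trtllm"))  # succeeds even if cudnn is not installed
```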
1 parent: 40471bf

1 file changed

flashinfer/gemm.py

Lines changed: 15 additions & 12 deletions
@@ -1349,8 +1349,6 @@ def mm_fp4(
     >>> out.shape
     torch.Size([48, 256])
     """
-    _check_cudnn_fp4_availability()
-
     # pre-check the input tensor, block scale tensor and alpha tensor
     if a.ndim != 2 or b.ndim != 2:
         raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}")
@@ -1397,21 +1395,26 @@ def mm_fp4(
         dtype=out_dtype,
     )
 
-    # the fp4 cudnn graph will be shared for both mm and bmm, so here we need to get the 3d shape and stride including the batch dimension for both input and block scale tensors.
-    real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
-    real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
-    batch = real_a_shape[0]
-    expanded_a_descale_shape, expanded_a_descale_stride = (
-        _expand_block_scale_tensor_shape(a_descale, batch)
-    )
-    expanded_b_descale_shape, expanded_b_descale_stride = (
-        _expand_block_scale_tensor_shape(b_descale, batch)
-    )
     workspace_buffer = _get_cache_buf(
         "mm_fp4_workspace", DEFAULT_WORKSPACE_SIZE, a.device
     )
 
     if backend == "cudnn":
+        _check_cudnn_fp4_availability()
+
+        # the fp4 cudnn graph will be shared for both mm and bmm, so
+        # here we need to get the 3d shape and stride including the
+        # batch dimension for both input and block scale tensors.
+        real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+        real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+        batch = real_a_shape[0]
+        expanded_a_descale_shape, expanded_a_descale_stride = (
+            _expand_block_scale_tensor_shape(a_descale, batch)
+        )
+        expanded_b_descale_shape, expanded_b_descale_stride = (
+            _expand_block_scale_tensor_shape(b_descale, batch)
+        )
+
         # build the fp4 cudnn graph
         graph = build_cudnn_gemm_block_scale_dequantize_graph(
             real_a_shape,
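Net effect: both the `_check_cudnn_fp4_availability()` call and the cudnn-specific shape/stride preparation (`_get_real_fp4_shape_from_packed_uint8`, `_expand_block_scale_tensor_shape`) now run only inside the `if backend == "cudnn":` branch, so callers selecting another backend on a machine without a usable cuDNN installation no longer fail at this call site.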
