Commit d910f9a

Improve graph caching of cudnn graph (#1887)
The graph cache key currently includes alpha and the device. This can be improved by keying only on whether alpha is present, not on its value.

## 📌 Description

Use only the presence of alpha (`alpha is not None`) in graph creation; the actual value of alpha is bound later, at execution time.

## 🔍 Related Issues

None

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
1 parent c08b529 commit d910f9a

File tree

1 file changed: +4 −4 lines changed

flashinfer/gemm.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -1198,7 +1198,7 @@ def build_cudnn_gemm_block_scale_dequantize_graph(
     o_type,
     block_size,
     device,
-    alpha,
+    alpha_is_not_none,
     use_nvfp4,
 ):
     _check_cudnn_availability()
@@ -1251,7 +1251,7 @@ def build_cudnn_gemm_block_scale_dequantize_graph(
 
     c_final_cudnn_tensor = c_tensor
 
-    if alpha is not None:
+    if alpha_is_not_none:
         global_scale_cudnn_tensor = graph.tensor(
             name="global_scale",
             dim=(1, 1, 1),
@@ -1280,7 +1280,7 @@ def build_cudnn_gemm_block_scale_dequantize_graph(
 
     # WAR: The alpha (contains the global scale) is not supported by the cuBLAS backend (eng0)
     # in older cuDNN versions, so we deselect it.
-    if (alpha is not None) and (not _is_cublas_fp4_available_in_cudnn()):
+    if (alpha_is_not_none) and (not _is_cublas_fp4_available_in_cudnn()):
         graph.deselect_engines(["eng0"])
     graph.check_support()
     graph.build_plans()
@@ -1710,7 +1710,7 @@ def mm_fp4(
         _torch_data_type_to_cudnn_data_type(out_dtype),
         block_size,
         a.device,
-        alpha,
+        alpha is not None,
         use_nvfp4,
     )
```

0 commit comments