Skip to content

Commit 8e926de

Browse files
fzyzcjyyyihuang
andauthored
Fix cute dsl gemm API wrong arg name and silent error when passing wrong kwargs (#1619)
<!-- .github/pull_request_template.md --> ## 📌 Description * fix name error: not "mm" but "mn" * make it explicitly fail now. o/w e.g. when users need to pass in `signals` and wrongly use a old flashinfer version, they will not realize it and get confusing errors. <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: Avery Yingyi Huang <[email protected]>
1 parent da937d7 commit 8e926de

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

flashinfer/cute_dsl/blockscaled_gemm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2811,13 +2811,15 @@ def grouped_gemm_nt_masked(
28112811
# Note: only support deepgemm-like shape for now
28122812
k = k * 2
28132813

2814-
mma_tiler_mn = kwargs.get("mma_tiler_mm", (128, 128))
2815-
cluster_shape_mn = kwargs.get("cluster_shape_mm", (1, 1))
2814+
mma_tiler_mn = kwargs.pop("mma_tiler_mn", (128, 128))
2815+
cluster_shape_mn = kwargs.pop("cluster_shape_mn", (1, 1))
28162816
if sm_count is None:
28172817
sm_count = get_num_sm(a_torch.device)
28182818

2819-
alpha = kwargs.get("alpha")
2820-
alpha_dtype = kwargs.get("alpha_dtype")
2819+
alpha = kwargs.pop("alpha", None)
2820+
alpha_dtype = kwargs.pop("alpha_dtype", None)
2821+
2822+
assert len(kwargs) == 0, f"Unsupported kwargs: {kwargs}"
28212823

28222824
major, minor = get_compute_capability(a_torch.device)
28232825
if major == 11 and minor == 0:

0 commit comments

Comments
 (0)