feat: masked layout fp4 gemm using cute-dsl #1331
Conversation
There is still a lot of work to be done; I've left those items for future PRs. Let's unblock users and test functionality first.
return self._num_tiles_executed

"""
I think it's safe to delete this docstring?
a_tensor = cute.make_tensor(
    a_ptr,
    layout=cute.make_ordered_layout(
        (self._m, self._k, self._l),
This is a non-blocking comment. This assumes static shapes if m/k/l are passed via members; just double-check that's what we expect here. To support dynamic shapes, m/k/l must be passed through run_cute_ptr's argument list as Int32, I believe.
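A rough, untested sketch of the two options being discussed, in the spirit of the snippet above. The function names, the order argument, and the exact CuTe DSL signatures are my assumptions, not the PR's actual code:

```python
import cutlass
import cutlass.cute as cute


@cute.jit
def run_cute_static(a_ptr: cute.Pointer):
    # Shapes captured from plain Python ints (e.g. self._m/_k/_l) are baked
    # into the compiled kernel, so each distinct (m, k, l) needs its own
    # compilation.
    m, k, l = 256, 7168, 8  # illustrative values only
    a_tensor = cute.make_tensor(
        a_ptr, layout=cute.make_ordered_layout((m, k, l), order=(1, 0, 2))
    )
    # ... kernel body elided


@cute.jit
def run_cute_dynamic(
    a_ptr: cute.Pointer, m: cutlass.Int32, k: cutlass.Int32, l: cutlass.Int32
):
    # Shapes arrive as runtime Int32 arguments, so one compiled kernel serves
    # any (m, k, l), typically at the cost of less specialized SASS.
    a_tensor = cute.make_tensor(
        a_ptr, layout=cute.make_ordered_layout((m, k, l), order=(1, 0, 2))
    )
    # ... kernel body elided
```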
I think it's currently okay to assume static shapes: the number of groups depends on the TP/EP size, N/K are fixed, and we can compile one kernel per cudagraph configuration. For M we can just set a maximum possible value; the kernel execution time will only depend on the values in the mask_m tensor, not on M.
cc @kaixih for confirmation.
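A tiny plain-Python sketch of the scheduling argument above (not the PR's tile scheduler; the tile size of 128 is just an assumption): the number of M-tiles launched per expert group is derived from mask_m, so compiling with a large static maximum M adds no extra work.

```python
def num_m_tiles_per_group(mask_m, tile_m=128):
    # Only ceil(mask_m[g] / tile_m) M-tiles are scheduled for each group.
    return [(m + tile_m - 1) // tile_m for m in mask_m]


# M may be compiled as a static maximum (e.g. 4096), but the scheduled work
# scales with the per-group masks.
print(num_m_tiles_per_group([96, 257, 0, 1024]))  # -> [1, 3, 0, 8]
```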
Cool. I think that's also one of the advantages of using jit here: you can selectively choose static shapes, which usually ends up producing better SASS.
📌 Description
Implement FP4 GEMM (with masked layout) requested in sgl-project/sglang#7994.
Adapted from CUTLASS's dense_blockscaled_gemm_persistent example, with a DeepGEMM-style tile scheduler.
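For reference, a hedged sketch of the masked-layout semantics, with FP4/block-scale quantization omitted and plain float tensors in its place. The shape convention ([num_groups, max_m, k] activations, [num_groups, n, k] weights, per-group mask_m of valid rows) is inferred from the discussion above, not taken from the PR's API:

```python
import torch


def masked_grouped_gemm_ref(a, b, mask_m):
    # a: [num_groups, max_m, k], b: [num_groups, n, k], mask_m: [num_groups].
    # Only the first mask_m[g] rows of each group are valid; outputs for the
    # padded rows are left as zeros.
    num_groups, max_m, _ = a.shape
    n = b.shape[1]
    out = torch.zeros(num_groups, max_m, n, dtype=a.dtype, device=a.device)
    for g in range(num_groups):
        m = int(mask_m[g])
        out[g, :m] = a[g, :m] @ b[g].transpose(0, 1)
    return out
```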
🔍 Related Issues
sgl-project/sglang#7994
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the pre-commit hooks with pre-commit install.
- I have run pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).
Reviewer Notes
cc @fzyzcjy
Co-authored-by: Avery Huang [email protected]