[Quantization] Add cutlass kernel for FP8 #43304
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
    _quantization_kernel = get_kernel("RedHatAI/quantization")
except Exception as e:
    logger.warning_once(f"Failed to load CUTLASS quantization kernel: {e}. Falling back to Triton.")
```
Do we want to also log that we're using the RedHat kernel in case it was successfully loaded?
Yes we can do that I think
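A minimal sketch of what that could look like, assuming the loader shown later in this PR (the success-path `logger.info` line is an illustrative addition, not code from the PR):

```python
try:
    from .hub_kernels import get_kernel

    _quantization_kernel = get_kernel("RedHatAI/quantization")
    # Hypothetical addition: report which backend was selected.
    logger.info("Using CUTLASS quantization kernel from RedHatAI/quantization.")
except Exception as e:
    logger.warning_once(f"Failed to load CUTLASS quantization kernel: {e}. Falling back to Triton.")
```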
```python
Check if CUTLASS blockwise FP8 matmul is supported for the given block size.

CUTLASS blockwise kernels require:
- SM90+ (Hopper or newer)
```
This is fine IMO. When hardware is available, users should be able to max it out!
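For context, a minimal sketch of such a capability check (the helper name and exact logic are assumptions, not the PR's implementation):

```python
import torch


def cutlass_blockwise_supported(block_size) -> bool:
    """Return True if CUTLASS blockwise FP8 matmul can run on this device."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    # SM90+ (Hopper or newer); the PR benchmarks block size (128, 128).
    return major >= 9 and tuple(block_size) == (128, 128)
```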
SunMarc left a comment:
Thanks! Just a few nits.
```python
kernel = _get_quantization_kernel()
if kernel is None:
    return False
```
I think we also need to check that `kernels` is installed, no?
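A minimal sketch of such a guard using only the standard library (the helper name is hypothetical):

```python
import importlib.util


def _kernels_available() -> bool:
    """Return True if the `kernels` package is importable."""
    return importlib.util.find_spec("kernels") is not None
```

This could be checked before calling `_get_quantization_kernel()`, so the fallback to Triton happens without even attempting a Hub download.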
```python
# Global for the CUTLASS quantization kernel (lazily loaded)
_quantization_kernel = None


def _get_quantization_kernel():
    """Lazily load the CUTLASS quantization kernel from HuggingFace Hub."""
    global _quantization_kernel
    if _quantization_kernel is None:
        try:
            from .hub_kernels import get_kernel

            _quantization_kernel = get_kernel("RedHatAI/quantization")
        except Exception as e:
            logger.warning_once(f"Failed to load CUTLASS quantization kernel: {e}. Falling back to Triton.")
```
Instead of a single kernel, can we try to create a dict where we store multiple kernels? We can leave this for a follow-up PR, but the idea would be to also move the Triton kernels into `kernels`.
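A minimal sketch of that follow-up idea, reusing the same `get_kernel` loader (the registry and helper names are hypothetical):

```python
# Hypothetical registry: cache Hub kernels by repo id instead of one global.
_kernel_cache = {}


def _get_hub_kernel(repo_id: str):
    """Lazily load and cache a kernel from the HuggingFace Hub."""
    if repo_id not in _kernel_cache:
        from .hub_kernels import get_kernel

        _kernel_cache[repo_id] = get_kernel(repo_id)
    return _kernel_cache[repo_id]


# e.g. _get_hub_kernel("RedHatAI/quantization")
```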
SunMarc left a comment:
Thanks! Let's merge this.
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43304&sha=cfd4b9
What does this PR do?
Adds the CUTLASS kernel for scaled matmul. Performance is much better than Triton for the specific block size (128, 128).
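For reference, a generic CUDA-event timing harness of the kind used for such comparisons (a sketch, not the benchmark code from this PR):

```python
import torch


def time_matmul_ms(fn, a, b, iters=50, warmup=10):
    """Average milliseconds per call for a CUDA matmul backend `fn`."""
    for _ in range(warmup):
        fn(a, b)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(a, b)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```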
All FP8 tests passing!