(draft) [feat] FP8 MoE per-channel quant support #2809
raayandhar wants to merge 1 commit into flashinfer-ai:main
Conversation
Signed-off-by: raayandhar <raayan.dhar@gmail.com>
Summary of Changes
This pull request introduces support for FP8 per-channel quantization in Mixture-of-Experts (MoE) operations within the FlashInfer framework. It extends the existing TRTLLM fused MoE kernel to handle per-channel scaling for weights, providing a new quantization mode alongside the previously supported per-tensor and block-scale methods. This enhancement aims to improve the efficiency and accuracy of MoE models by allowing more granular quantization of model weights.
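To make the new mode concrete: per-tensor quantization stores a single dequant scale for the whole weight tensor, while per-channel quantization stores one scale per output channel. Below is a minimal PyTorch sketch of per-channel FP8 weight quantization; the function name and the standalone-PyTorch framing are illustrative assumptions, not this PR's actual code path.

import torch

FP8_E4M3_MAX = 448.0  # max finite magnitude of torch.float8_e4m3fn

def quantize_weight_per_channel_fp8(weight: torch.Tensor):
    # Sketch only: weight is [out_channels, in_channels]; we keep one
    # scale per output row instead of the single scalar a per-tensor
    # scheme would use.
    amax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_E4M3_MAX                       # dequant scale
    w_fp8 = (weight / scale).to(torch.float8_e4m3fn)  # quantize
    return w_fp8, scale.squeeze(1)                    # [out, in], [out]

w = torch.randn(2048, 4096)
w_fp8, scales = quantize_weight_per_channel_fp8(w)
w_dequant = w_fp8.float() * scales.unsqueeze(1)  # approximate round trip

Because each output channel gets its own scale, outlier channels no longer force the whole tensor onto a coarse quantization grid, which is the accuracy benefit the summary refers to.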
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces support for FP8 per-channel quantization in MoE kernels, which is a significant feature addition. The changes are well-structured, adding a new Fp8PerChannelLauncher and plumbing the necessary parameters through the C++ and Python layers, including new tests. The implementation largely follows existing patterns in the codebase. I've identified a couple of minor areas for improvement in the C++ code to enhance robustness by avoiding hardcoded values related to activation functions.
TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.ndim(), 2)
    << "gemm1_per_channel_weight_scale must be 2D [local_num_experts, 2*intermediate_size].";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(0), args->local_num_experts)
    << "gemm1_per_channel_weight_scale dim 0 must match local_num_experts.";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(1), 2 * args->intermediate_size)
    << "gemm1_per_channel_weight_scale dim 1 must be 2*intermediate_size.";
The dimension checks and error messages for gemm1_per_channel_weight_scale_ hardcode the multiplier 2, which assumes a gated activation. It would be more robust to use the intermediate_size_factor member from the base class, which is correctly set based on the activation type. This will ensure correctness if non-gated activations are used with this launcher in the future.
TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.ndim(), 2)
<< "gemm1_per_channel_weight_scale must be 2D [local_num_experts, "
"intermediate_size_factor*intermediate_size].";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(0), args->local_num_experts)
<< "gemm1_per_channel_weight_scale dim 0 must match local_num_experts.";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(1),
intermediate_size_factor * args->intermediate_size)
<< "gemm1_per_channel_weight_scale dim 1 must match intermediate_size_factor * "
"intermediate_size.";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.ndim(), 2)
    << "gemm1_per_channel_gate_weight_scale must be 2D [local_num_experts, "
       "2*intermediate_size].";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(0), args->local_num_experts)
    << "gemm1_per_channel_gate_weight_scale dim 0 must match local_num_experts.";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(1), 2 * args->intermediate_size)
    << "gemm1_per_channel_gate_weight_scale dim 1 must be 2*intermediate_size.";
Similar to the check for gemm1_per_channel_weight_scale_, the dimension checks and error messages for gemm1_per_channel_gate_weight_scale_ hardcode the multiplier 2. Using intermediate_size_factor would make this more robust and consistent.
TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.ndim(), 2)
<< "gemm1_per_channel_gate_weight_scale must be 2D [local_num_experts, "
"intermediate_size_factor*intermediate_size].";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(0), args->local_num_experts)
<< "gemm1_per_channel_gate_weight_scale dim 0 must match local_num_experts.";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(1),
intermediate_size_factor * args->intermediate_size)
<< "gemm1_per_channel_gate_weight_scale dim 1 must match intermediate_size_factor * "
"intermediate_size.";
Will test later today
📌 Description
See the linked issue. The TRTLLM kernel seems to already expose some support for this, but the relevant argument was set to nullptr; this is mostly plumbing work to get it hooked up. Still needs testing: unsure whether this is functional and whether we need a cubin.
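For anyone verifying the plumbing, the scale shapes implied by the review's dimension checks can be sanity-checked in isolation. A sketch assuming a gated MLP and FP8 e4m3 weights; the variable names mirror the check messages above, and the fused-MoE call itself is omitted since the exact Python entry point isn't shown in this excerpt:

import torch

local_num_experts, intermediate_size, hidden_size = 8, 1024, 4096

# Fused gemm1 weights for a gated MLP: gate and up stacked per expert.
gemm1_weights = torch.randn(local_num_experts, 2 * intermediate_size, hidden_size)

# One dequant scale per (expert, output channel), reduced over the input dim.
amax = gemm1_weights.abs().amax(dim=-1).clamp(min=1e-12)
gemm1_per_channel_weight_scale = amax / 448.0  # FP8 e4m3 max magnitude

# Matches the TVM_FFI_ICHECK_EQ expectations in the review comments.
assert gemm1_per_channel_weight_scale.shape == (local_num_experts, 2 * intermediate_size)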
🔍 Related Issues
#2419
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes