(draft) [feat] FP8 MoE per-channel quant support #2809

Draft
raayandhar wants to merge 1 commit into flashinfer-ai:main from raayandhar:users/rdhar/moe-per-channel-quant

Conversation

@raayandhar
Contributor

@raayandhar raayandhar commented Mar 17, 2026

📌 Description

See the linked issue. The TRTLLM kernel already exposes some support for this, but it was set to nullptr; this PR is mostly plumbing work to wire it up.

Still needs testing; unsure whether this is functional and whether a cubin is needed.

🔍 Related Issues

#2419

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Signed-off-by: raayandhar <raayan.dhar@gmail.com>
@coderabbitai
Contributor

coderabbitai bot commented Mar 17, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f7990570-e01b-4dbb-893f-f71bd6dd8e14

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for FP8 per-channel quantization in Mixture-of-Experts (MoE) operations within the FlashInfer framework. It extends the existing TRTLLM fused MoE kernel to handle per-channel scaling for weights, providing a new quantization mode alongside the previously supported per-tensor and block-scale methods. This enhancement aims to improve the efficiency and accuracy of MoE models by allowing more granular quantization of model weights.

Highlights

  • New Quantization Type: Introduced PerChannelFp8 as a new FP8 quantization type for Mixture-of-Experts (MoE) operations, enabling more granular scaling of weights.
  • TRTLLM Kernel Integration: Implemented a dedicated Fp8PerChannelLauncher class and trtllm_fp8_per_channel_scale_moe function to handle the specifics of per-channel FP8 quantization within the TRTLLM fused MoE kernel.
  • Python API Exposure: Exposed a new Python API function, trtllm_fp8_per_channel_scale_moe, allowing users to leverage this new quantization mode from the FlashInfer library.
  • Testing Framework Update: Integrated the per-channel FP8 MoE support into the testing framework, including new quantization logic, kernel execution, and reference computations to ensure correctness.
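The per-channel scheme the summary describes assigns one scale per output channel of each weight matrix instead of one scale for the whole tensor. A minimal numpy sketch of that idea (the helper names here are illustrative, not the PR's actual functions, and clipping stands in for real FP8 e4m3 rounding):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 e4m3

def quant_fp8_per_channel(w):
    """Quantize a 2D weight [out_channels, in_features] with one scale
    per output channel (row), computed from the row-wise absmax."""
    amax = np.maximum(np.abs(w).max(axis=1, keepdims=True), 1e-12)
    scale = amax / FP8_E4M3_MAX                  # dequantization scale per row
    w_q = np.clip(w / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return w_q, scale.squeeze(1)                 # quantized weights + scales

def dequant_per_channel(w_q, scale):
    """Recover each row's original magnitude."""
    return w_q * scale[:, None]
```

Because each scale tracks its own row's absmax, an outlier channel no longer forces a coarse scale onto the entire tensor, which is the accuracy benefit over per-tensor FP8.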

Changelog
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Added PerChannelFp8 to the Fp8QuantizationType enum.
    • Implemented the Fp8PerChannelLauncher class, inheriting from FusedMoeLauncher, to manage per-channel FP8 MoE operations.
    • Added the trtllm_fp8_per_channel_scale_moe function to handle the execution of per-channel FP8 MoE.
  • csrc/trtllm_fused_moe_runner.cu
    • Modified Runner::run methods for mPermuteGemm1 and GEMM2 to accept an optional perChannelWeightScale argument.
  • flashinfer/fused_moe/init.py
    • Exported the new trtllm_fp8_per_channel_scale_moe function.
  • flashinfer/fused_moe/core.py
    • Extended Fp8QuantizationType enum with PerTensorFp8 and PerChannelFp8 values.
    • Integrated the trtllm_fp8_per_channel_scale_moe call within the forward method of the MoE runner, conditional on the Fp8QuantizationType.
    • Registered custom and fake operations for trtllm_fp8_per_channel_scale_moe.
  • include/flashinfer/trtllm/fused_moe/runner.h
    • Updated Runner::run method signatures to include an optional perChannelWeightScale parameter.
    • Added gemm1_per_channel_weight_scale, gemm1_per_channel_gate_weight_scale, and gemm2_per_channel_weight_scale fields to the MoERunnerArgs struct.
  • tests/moe/test_trtllm_gen_fused_moe.py
    • Introduced the FP8PerChannelMoe class to encapsulate per-channel FP8 quantization logic, including quantize_weights, quantize_inputs, prepare_static_weights_for_kernel, call_moe, and compute_reference methods.
    • Added FP8PerChannelMoe to various pytest.param lists to enable testing of the new quantization mode.
    • Added gemm1_per_channel_scales and gemm2_per_channel_scales to moe_args_dequant constructor.
    • Implemented quant_fp8_per_channel and quant_fp8_per_channel_batches for per-channel quantization.
    • Updated run_moe_dequant to handle FP8_PER_CHANNEL mode.
    • Added run_moe_reference_per_channel_scale_fp8 for reference computation.
  • tests/moe/utils.py
    • Added FP8_PER_CHANNEL to the QuantMode enum.
    • Included QuantMode.FP8_PER_CHANNEL in NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES.
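The reference computations listed above have a simple invariant to reproduce: with per-channel weight scales, each output column of a GEMM carries its own multiplier, so the scales can be folded in after the float32 accumulation. A sketch of that check (the function name and the per-tensor activation scale are assumptions for illustration, not the PR's actual API):

```python
import numpy as np

def gemm_per_channel_ref(x_q, w_q, x_scale, w_scale):
    """Reference for an FP8-style GEMM: x_q [tokens, k] with a single
    per-tensor activation scale, w_q [n, k] with one scale per output
    channel. Accumulate in float32, then fold both scales into the output."""
    acc = x_q.astype(np.float32) @ w_q.astype(np.float32).T  # [tokens, n]
    # Activation scale is a scalar; weight scales broadcast across columns.
    return acc * x_scale * w_scale[np.newaxis, :]
```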
Activity
  • The author, raayandhar, opened this pull request as a draft to introduce FP8 MoE per-channel quantization support.
  • Pre-commit checks and initial tests have been completed by the author.
  • The author noted that further testing is required and expressed uncertainty about the functionality and the need for cubin files.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for FP8 per-channel quantization in MoE kernels, which is a significant feature addition. The changes are well-structured, adding a new Fp8PerChannelLauncher and plumbing the necessary parameters through the C++ and Python layers, including new tests. The implementation largely follows existing patterns in the codebase. I've identified a couple of minor areas for improvement in the C++ code to enhance robustness by avoiding hardcoded values related to activation functions.

Comment on lines +853 to +858
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.ndim(), 2)
        << "gemm1_per_channel_weight_scale must be 2D [local_num_experts, 2*intermediate_size].";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(0), args->local_num_experts)
        << "gemm1_per_channel_weight_scale dim 0 must match local_num_experts.";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(1), 2 * args->intermediate_size)
        << "gemm1_per_channel_weight_scale dim 1 must be 2*intermediate_size.";

medium

The dimension checks and error messages for gemm1_per_channel_weight_scale_ hardcode the multiplier 2, which assumes a gated activation. It would be more robust to use the intermediate_size_factor member from the base class, which is correctly set based on the activation type. This will ensure correctness if non-gated activations are used with this launcher in the future.

    TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.ndim(), 2)
        << "gemm1_per_channel_weight_scale must be 2D [local_num_experts, "
           "intermediate_size_factor*intermediate_size].";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(0), args->local_num_experts)
        << "gemm1_per_channel_weight_scale dim 0 must match local_num_experts.";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(1),
                      intermediate_size_factor * args->intermediate_size)
        << "gemm1_per_channel_weight_scale dim 1 must match intermediate_size_factor * "
           "intermediate_size.";

Comment on lines +862 to +868
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.ndim(), 2)
        << "gemm1_per_channel_gate_weight_scale must be 2D [local_num_experts, "
           "2*intermediate_size].";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(0), args->local_num_experts)
        << "gemm1_per_channel_gate_weight_scale dim 0 must match local_num_experts.";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(1), 2 * args->intermediate_size)
        << "gemm1_per_channel_gate_weight_scale dim 1 must be 2*intermediate_size.";

medium

Similar to the check for gemm1_per_channel_weight_scale_, the dimension checks and error messages for gemm1_per_channel_gate_weight_scale_ hardcode the multiplier 2. Using intermediate_size_factor would make this more robust and consistent.

    TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.ndim(), 2)
        << "gemm1_per_channel_gate_weight_scale must be 2D [local_num_experts, "
           "intermediate_size_factor*intermediate_size].";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(0), args->local_num_experts)
        << "gemm1_per_channel_gate_weight_scale dim 0 must match local_num_experts.";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(1),
                      intermediate_size_factor * args->intermediate_size)
        << "gemm1_per_channel_gate_weight_scale dim 1 must match intermediate_size_factor * "
           "intermediate_size.";

@raayandhar
Contributor Author

Will test later today
