
Support Sigmoid (sigmoid+topk) routing function#2869

Draft
EdalatiAli wants to merge 2 commits into flashinfer-ai:main from EdalatiAli:sigmoid_routing

Conversation


EdalatiAli commented Mar 24, 2026

📌 Description

Depends on #2803 .

This PR adds RoutingMethodType.Sigmoid, a routing function that applies sigmoid before top-k selection (without renormalization), for use by MoE layers that route with this scheme.
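For reference, the intended semantics can be sketched in a few lines of Python. This is a hedged illustration of what sigmoid+topk routing computes, not the PR's CUDA kernel; the function name and numpy formulation are illustrative:

```python
import numpy as np

def sigmoid_topk_routing(logits, top_k):
    """Sketch of Sigmoid routing: sigmoid over the router logits, then
    top-k expert selection. The selected sigmoid scores are returned
    as-is, i.e. they are NOT renormalized to sum to 1."""
    scores = 1.0 / (1.0 + np.exp(-logits))               # sigmoid activation
    experts = np.argsort(-scores, axis=-1)[..., :top_k]  # top-k expert indices
    weights = np.take_along_axis(scores, experts, axis=-1)
    return experts, weights

logits = np.array([[2.0, -1.0, 0.5, 1.5]])
experts, weights = sigmoid_topk_routing(logits, top_k=2)
# experts -> [[0, 3]]; weights keep the raw sigmoid values (their sum != 1)
```

The absence of the renormalization step is the distinguishing point versus the existing Renormalize-style methods: the per-expert weights stay in (0, 1) independently rather than forming a distribution over the selected experts.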

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
Contributor

coderabbitai bot commented Mar 24, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c0cc0cc2-d95b-4302-834c-f3829927170f


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly!

This pull request significantly enhances the Mixture-of-Experts (MoE) routing capabilities by introducing new sigmoid-based routing functions and a robust, policy-driven architecture. The changes aim to provide greater flexibility and support for diverse MoE model designs, ensuring efficient and accurate expert selection under various activation and normalization schemes. The refactoring also streamlines the codebase by consolidating common routing utilities and optimizing kernel instantiations.

Highlights

  • New Routing Methods: Introduced RoutingMethodType.Sigmoid and RoutingMethodType.SigmoidRenorm to support MoE layers that apply a sigmoid activation before TopK selection, with or without subsequent renormalization.
  • Policy-Based Routing Refactor: Refactored the core routing logic into a flexible policy-based system (routingCustom), allowing for modular definition of preprocessing (e.g., Softmax, Sigmoid, Sigmoid+Bias) and postprocessing (e.g., Softmax, SumNormalize, ScaledSumNormalize) steps.
  • Consolidated Routing Kernels: Consolidated common post-TopK routing pipeline steps (histogram, offsets, permutation) into a shared utility (trtllm_fused_moe_routing_common.cu), reducing code duplication across different routing methods.
  • Dynamic Parameter Passing: Updated kernel parameter structures to pass mUsePdl and mIsPow2 as runtime boolean flags instead of compile-time template parameters, reducing the number of kernel instantiations.
  • Dtype Flexibility for Routing Inputs: Enhanced flexibility for routing logits and bias data types, allowing both float32 and bfloat16 inputs across various MoE backends.
  • Expanded Test Coverage: Added comprehensive test cases for the new Default, SigmoidRenorm, and Sigmoid routing methods, along with tests verifying the data type flexibility for routing inputs.
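The policy split described in the highlights can be illustrated with a small Python sketch. The names and structure below are illustrative of the preprocess/postprocess idea, not the kernel's actual identifiers:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sum_normalize(w):
    # Postprocess used by the "renorm" variants: top-k weights sum to 1.
    return w / w.sum(axis=-1, keepdims=True)

def identity(w):
    return w

# A routing method modeled as a (preprocess, postprocess) policy pair
# applied around top-k selection (names here are hypothetical).
POLICIES = {
    "Softmax":       (softmax, identity),
    "Sigmoid":       (sigmoid, identity),        # new in this PR
    "SigmoidRenorm": (sigmoid, sum_normalize),
}

def route(logits, top_k, method):
    pre, post = POLICIES[method]
    scores = pre(logits)
    experts = np.argsort(-scores, axis=-1)[..., :top_k]
    weights = np.take_along_axis(scores, experts, axis=-1)
    return experts, post(weights)
```

Under this framing, adding a routing method means registering a new policy pair rather than writing a new kernel, which is what lets the PR collapse many kernel instantiations into the shared `routingCustom` path.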


Contributor

gemini-code-assist bot left a comment


Code Review

The pull request refactors the MoE routing kernels around a new policy-based dispatch system: common routing logic is consolidated into trtllm_fused_moe_routing_common.cu and trtllm_fused_moe_routing_custom.cu, trtllm_fused_moe_routing_renormalize.cu is removed, new SigmoidRenorm and Sigmoid routing methods are added, and a norm_topk_prob parameter now controls top-K probability normalization. Review comments raise the following concerns:

  • The removal of per-method dtype checks for routing_logits in trtllm_fused_moe_kernel_launcher.cu; specific checks for each routing method should be maintained.
  • Use data_ptr() directly instead of static_cast<float*> for routing_logits to prevent type-casting issues, and use loadScalar for routing_bias so its dtype is handled correctly.
  • The checks on routing_bias dimensions and shape should be conditional if not all routing methods use the bias.
  • The hardcoded expert-tier values in trtllm_fused_moe_routing_deepseek.cu are not scalable and should be configurable.
  • Minor comment inaccuracies regarding mUsePdl and mPdlOverlapWithNext in trtllm_fused_moe_routing_custom.cu.

I am having trouble creating individual review comments. Click here to see my feedback.

csrc/trtllm_fused_moe_kernel_launcher.cu (1799-1804)

high

The removal of the specific dtype checks for routing_logits based on RoutingMethodType is concerning. It's important to ensure that all routing methods now correctly handle both dl_float32 and dl_bfloat16 for routing_logits. If certain routing methods still require a specific dtype, this change could introduce errors. It's better to have specific checks for each routing method.

csrc/trtllm_fused_moe_kernel_launcher.cu (393-397)

high

The addition of mRoutingLogitsDtype and norm_topk_prob as arguments to the routing_runner.run function call is correct, but it's crucial to ensure that these parameters are correctly passed and handled in all subsequent calls to this function throughout the codebase. It's important to verify that the data types and values are consistent with the expected behavior of the routing kernel.

csrc/trtllm_fused_moe_kernel_launcher.cu (904-907)

high

The static_cast<float*> here is concerning. It's better to use data_ptr() directly and let the kernel handle the type. This is especially important if routing_logits is not always a float*. This could lead to incorrect memory access or type casting issues.

      args->routing_logits = static_cast<float*>(routing_logits.value().data_ptr());

csrc/trtllm_fused_moe_routing_deepseek.cu (162-165)

medium

Using loadScalar is better than static_cast<float>(params.mPtrRoutingBias[threadExpert]) as it handles the type correctly. This avoids potential type casting issues.

                     ? static_cast<OutputT>(
                           loadScalar(params.mPtrRoutingBias, threadExpert, params.mDtypeBias))
                     : invalidScore;
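In spirit, the loadScalar suggestion dispatches on a runtime dtype tag instead of assuming float32. A rough Python analogue of that idea (a hypothetical helper assuming little-endian buffers; the real code is CUDA C++):

```python
import numpy as np

def load_scalar(buf, index, dtype_tag):
    """Read element `index` from a raw little-endian buffer according to a
    runtime dtype tag, widening the result to a Python float."""
    if dtype_tag == "float32":
        return float(np.frombuffer(buf, dtype="<f4")[index])
    if dtype_tag == "bfloat16":
        # bfloat16 is the high 16 bits of an IEEE float32: widen by
        # shifting the stored bits into the top half of a uint32 and
        # reinterpreting those bits as float32.
        raw = np.frombuffer(buf, dtype="<u2")[index]
        return float((np.array([raw], dtype=np.uint32) << 16).view("<f4")[0])
    raise ValueError(f"unsupported dtype tag: {dtype_tag}")
```

A blind static_cast<float> of a bfloat16 buffer would instead reinterpret pairs of 16-bit values as one float, which is the kind of silent corruption the reviewer is warning about.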

csrc/trtllm_fused_moe_kernel_launcher.cu (1735-1736)

medium

The type validation for routing_logits is now more general, allowing both dl_float32 and dl_bfloat16. This is good for flexibility, but it's crucial to ensure that the kernel implementation correctly handles both data types. Add a comment to explain why the type check is now more permissive.

  TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
      << "FP8 per-tensor MoE: routing_logits must be float or bfloat16."; // Allow both float32 and bfloat16

csrc/trtllm_fused_moe_kernel_launcher.cu (929-930)

medium

Using routing_logits.value().data_ptr() is better than static_cast<float*>(routing_logits.value().data_ptr()) as it avoids potential type casting issues. The kernel should handle the type appropriately.

      args->routing_logits = routing_logits.value().data_ptr();

csrc/trtllm_fused_moe_kernel_launcher.cu (1840-1846)

medium

The addition of checks for routing_bias dimensions and shape is good for robustness. However, it's important to ensure that these checks are consistent with the expected behavior of all routing methods that use routing_bias. If some methods don't use it, the check should be conditional.

csrc/trtllm_fused_moe_routing_deepseek.cu (67-77)

medium

The use of hardcoded values for expert tiers is not scalable. These values should be configurable or derived from the input data to allow for flexibility in different MoE configurations. Consider using a function or a lookup table to determine the appropriate expert tier based on the number of experts.

csrc/trtllm_fused_moe_kernel_launcher.cu (881-892)

medium

Adding RoutingMethodType::Sigmoid to this conditional block is correct for enabling the new routing method. However, it's important to ensure that the logic within this block is appropriate for all routing methods included, and that the comment accurately reflects the supported top_k values for all methods in the group. Consider updating the comment to be more general, or adding separate checks with distinct comments for each routing method if their requirements diverge.

      TVM_FFI_ICHECK(args->top_k <= 10 && args->top_k > 0)
          << "Current routing kernel (no groups) only supports top_k<=10 && top_k>0.";

csrc/trtllm_fused_moe_routing_custom.cu (48-50)

medium

The comment states that when MaxNumExperts > 1024 the actual thread count is capped at 1024, and the code confirms this: NumThreadsBlock is assigned MaxNumExperts <= 1024 ? MaxNumExperts : 1024. The comment further claims that each thread then handles multiple experts; please verify that the kernel really iterates over experts with a stride of NumThreadsBlock in that case.
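The capped-thread claim can be sanity-checked against a tiny model of the expected indexing (illustrative only; the helper name is invented and the real loop is CUDA):

```python
def experts_for_thread(thread_id, num_threads, num_experts):
    """With the block capped at num_threads threads, thread t is expected
    to cover experts t, t + num_threads, t + 2*num_threads, ... in a
    strided loop, so all num_experts experts are covered exactly once."""
    return list(range(thread_id, num_experts, num_threads))

# With 2048 experts and a 1024-thread cap, each thread covers two experts.
```

If the kernel's loop matches this pattern, the union over all 1024 threads covers every expert exactly once, which is what the comment asserts.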

csrc/trtllm_fused_moe_routing_custom.cu (82-84)

medium

Using params.mUsePdl is better than KernelParams::UsePdl as it uses the runtime value instead of the compile time value.

  if (params.mUsePdl) {

csrc/trtllm_fused_moe_routing_custom.cu (311-312)

medium

Using params.mUsePdl is better than KernelParams::UsePdl as it uses the runtime value instead of the compile time value.

  if (params.mUsePdl) {

csrc/trtllm_fused_moe_routing_custom.cu (571)

medium

The comment is not accurate. The code is checking for data.mUsePdl not KernelParams::UsePdl.

csrc/trtllm_fused_moe_routing_custom.cu (600)

medium

The comment is not accurate. The code is checking for mutableData.mPdlOverlapWithNext not data.mPdlOverlapWithNext.

csrc/trtllm_fused_moe_routing_custom.cu (634)

medium

The comment is not accurate. The code is checking for mutableData.mPdlOverlapWithNext not data.mPdlOverlapWithNext.

csrc/trtllm_fused_moe_routing_custom.cu (640)

medium

The comment is not accurate. The code is checking for mutableData.mPdlOverlapWithNext not data.mPdlOverlapWithNext.

Signed-off-by: EdalatiAli <aliedalati@cohere.com>
