
Refactor the routing part #2803

Open
ChristinaZ wants to merge 1 commit into flashinfer-ai:main from ChristinaZ:refactor_routing

Conversation


@ChristinaZ (Contributor) commented Mar 17, 2026

📌 Description

Refactor the routing part of trtllm-gen.

Execution Flow

Each routing method follows the same high-level pattern:

  1. TopK selection — Compute top-K experts per token from routing scores.
  2. Histogram — Count how many tokens are assigned to each expert.
  3. Offsets + Permutation — Prefix-scan expert counts to get offsets; build permutation indices mapping expanded token slots → padded expert-sorted positions.
  4. GEMM config — Write ctaIdxXyToBatchIdx, ctaIdxXyToMnLimit, numNonExitingCtas so the downstream batched GEMM knows how to partition work.

Depending on token count, different code paths are selected:

Token Count          Code Path
≤ 4                  Single-block kernel (fuses all steps, incl. topK)
≤ cluster capacity   Single-cluster kernel (uses distributed smem, incl. topK)
≤ coop capacity      Scores→topK kernel + cooperative kernel (fuses histogram + offsets)
Large                Scores→topK kernel + histogram kernel + offsets kernel

When topK is pre-computed (mPtrTopKIds or mPtrTopKPacked), the first two paths skip topK,
and the coop/large paths skip the scores→topK kernel.

runPostTopKPipeline<DataType>() handles the permutation pipeline when topK is already computed
(e.g., by DeepSeek's grouped main kernel). It converts any routing method's Data to
routingCustom::Data and dispatches through the appropriate token-count path (single-block,
single-cluster, coop, or multi-kernel). This avoids duplicating permutation code across methods.

Routing Methods and TierList Configuration

RoutingMethodType → Policy Mapping

Each RoutingMethodType (defined in runner.h) maps to a specific kernel path and policy
combination. The routingCustom method uses a policy-based design where expert selection logic is injected as a compile-time ExpertSelectPolicy template parameter, providing zero runtime overhead and high extensibility.

The default TopKExpertSelect<PreprocessPolicy, PostprocessPolicy> wraps the traditional preprocess → topK → postprocess pattern, while users can write completely custom policies that bypass this pattern (e.g., lookup-table-based expert selection).

Each policy owns its runtime data through a nested Params<OutputT> struct. When a policy doesn't need extra data, its Params is empty and costs zero registers. This avoids paying for unused fields (e.g., a routing bias pointer) in policy combinations that don't need them.

TierList and PolicyTraits

Each routing kernel is templated on MaxNumExperts and MaxNumTopExperts. These determine
shared memory sizes, loop bounds, and the __launch_bounds__ thread count. To avoid compiling
every combination, each policy declares which (MaxNumExperts, MaxTopK) pairs it supports via
a PolicyTraits specialization:

template <>
struct PolicyTraits<SigmoidBiasPreprocess, ScaledSumNormalizePostprocess>
{
    using Pairs = TierList<
        Tier<128, 8>,   // ≤128 experts, topK ≤ 8
        Tier<256, 8>,   // ≤256 experts, topK ≤ 8
        Tier<384, 8>,   // ≤384 experts, topK ≤ 8
        Tier<512, 8>,   // ≤512 experts, topK ≤ 8
        Tier<512, 22>   // ≤512 experts, topK ≤ 22  (Nemotron Super V3)
    >;
};

Dispatch: dispatchTierPairs iterates the TierList from first to last, picking the
first Tier<E, K> where numExperts ≤ E AND topK ≤ K. This means tighter tiers
must come first (sorted by E ascending, then K ascending within equal E).

Adding support for a new model: If a new model has a (numExperts, topK) combination
not covered by any existing tier, add a Tier<E, K> to the appropriate PolicyTraits
specialization. No other changes are needed — the dispatch macros are generic.

Available Policies

Policy                         Type   Params Fields                                       Description
NoOpPreprocess                 Pre    (empty)                                             Pass-through; BaseType = InputT
SoftmaxPreprocess              Pre    (empty)                                             Softmax over all expert scores; BaseType = float
SigmoidPreprocess              Pre    (empty)                                             sigmoid(score) (no bias); BaseType = float
SigmoidBiasPreprocess          Pre    ptrRoutingBias, dtypeBias                           sigmoid(score) + bias[expertIdx]; BaseType = float
NoOpPostprocess                Post   (empty)                                             No transformation
SoftmaxPostprocess             Post   (empty)                                             Softmax over selected top-K scores
SumNormalizePostprocess        Post   normTopkProb                                        Divide top-K scores by their sum
ScaledSumNormalizePostprocess  Post   ptrRoutingBias, dtypeBias, routeScale, sumEpsilon   Recover sigmoid, normalize by sum, apply scale

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Added norm_topk_prob parameter to control top-K probability normalization in MoE routing.
    • Introduced new SigmoidRenorm routing method type for enhanced routing flexibility.
    • Extended support for routing logits with FP32 and BF16 data types.
  • Improvements

    • Enhanced routing kernel efficiency across all MoE precision variants (BF16, FP8, FP4, MXInt4).
    • Refactored routing architecture for improved modularity and performance.
  • Bug Fixes

    • Improved handling of routing bias and logits dtype validation.

Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
coderabbitai bot commented Mar 17, 2026

📝 Walkthrough

Walkthrough

This PR introduces a policy-driven routing framework for MoE implementations, adding support for normalization of top-K probabilities and dynamic routing logits dtype tracking. It refactors routing kernel architectures across custom, DeepSeek, and Llama4 paths, consolidates post-topK pipeline logic, and updates launcher APIs to propagate new parameters through initialization and execution flows.

Changes

Cohort / File(s) Summary
Launcher API Extensions
csrc/trtllm_fused_moe_kernel_launcher.cu
Added norm_topk_prob boolean member and mRoutingLogitsDtype tracking; extended init_common and all launcher subclass init methods with norm_topk_prob parameter; propagated routing logits dtype through initialization and routing preparation paths.
Routing Framework & Policies
include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh
Introduced comprehensive policy-driven routing abstractions with preprocess (NoOp, Softmax, Sigmoid, SigmoidBias) and postprocess (NoOp, Softmax, SumNormalize, ScaledSumNormalize) policies; added TopKExpertSelect composition and dispatch macros for policy-based kernel launches.
Routing Kernel Architecture
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh, include/flashinfer/trtllm/fused_moe/RoutingKernel.h
Added loadScalar device helper for dtype-aware scalar reading; refactored KernelParams with policy-driven ExpertSelectParams; shifted to multi-expert-per-thread design; added cooperative kernel support (SM90+) with grid synchronization and PDL overlap control.
Routing Kernel Top-K Reduction
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh
Renamed MaxNumTopK to MaxSupportedTopExperts; refactored Sort infrastructure with dynamic Bitonic/Odd-Even selection; updated reduceTopK API and signature constraints; improved architecture detection for fast-reduction capabilities.
Routing Launch Macros
include/flashinfer/trtllm/fused_moe/RoutingDevKernel.h, include/flashinfer/trtllm/fused_moe/DevKernel.h
Added LAUNCH_PDL_ROUTING, LAUNCH_ROUTING_LLAMA4, and policy-based dispatch macros; consolidated routing kernel launch configuration with PDL serialization and programmatic launch support; removed legacy routing macros in favor of centralized RoutingDevKernel definitions.
Custom Routing Implementation
csrc/trtllm_fused_moe_routing_custom.cu
New file implementing multi-path fused routing: block kernel (≤4 tokens), cluster kernel (≤256 tokens, SM90+), histogram/coop/multi-kernel pathways; added run() orchestrator selecting paths based on input configuration and hardware capabilities.
Shared Post-TopK Pipeline
csrc/trtllm_fused_moe_routing_common.cu
New file centralizing post-topK pipeline logic with runPostTopKPipeline template; dispatches between single-block, single-cluster, cooperative, and multi-kernel paths based on token count and overlap control.
DeepSeek Routing Consolidation
csrc/trtllm_fused_moe_routing_deepseek.cu
Restructured with unified launcher wrappers (launchMainKernel, launchClusterKernel, launchCoopKernel); added coop path delegation to routingCustom; introduced dispatch macros for top-K and expert count routing; replaced internal impl structure with public run() entry point.
Llama4 Routing Refinement
csrc/trtllm_fused_moe_routing_llama4.cu
Updated header year; replaced static constexpr getMaxNumExperts with regular function; introduced cluster-aware scaling in token/offset calculations; refactored to use data.mUsePdl and mIsPow2; added early delegation to shared post-topK pipeline when TopK results provided.
Routing Renormalize Removal
csrc/trtllm_fused_moe_routing_renormalize.cu
Deleted entire file; functionality consolidated into custom routing path with policy-based approach via RoutingCustomPolicy.cuh.
Runner & Executor Updates
csrc/trtllm_fused_moe_runner.cu, include/flashinfer/trtllm/fused_moe/runner.h
Extended Runner::run signature with dtypeLogits and normTopkProb parameters; updated routing method dispatch to include SigmoidRenorm (value 6); remapped routing paths using policy abstractions; replaced direct routingRenormalize calls with routingCustom::run.
Python API & Bindings
flashinfer/fused_moe/core.py, flashinfer/jit/fused_moe.py
Added SigmoidRenorm routing method (value 6) and updated Unspecified to value 7; threaded norm_topk_prob parameter through all MoE operation signatures (BF16, FP8 variants, FP4, MxInt4); replaced renormalize.cu with custom.cu and common.cu in SM100 module compilation.
CUDA Utilities
include/flashinfer/trtllm/common/cudaUtils.h
Added getSMVersion() and getMultiProcessorCount() inline device-query helpers to namespace tensorrt_llm::common.
Test Infrastructure
tests/moe/test_trtllm_gen_fused_moe.py
Propagated norm_topk_prob through MoE runtime config and kernel invocations; added routing_reference_default (Softmax→TopK) and routing_reference_sigmoid_renorm (Sigmoid→TopK→optional renormalization); expanded test coverage for Default and SigmoidRenorm routing methods with multiple backends.

Sequence Diagram

sequenceDiagram
    participant Launcher as MoE Launcher
    participant Runner as Runner::run()
    participant RoutingDispatch as Routing Dispatch
    participant PolicyRouter as Policy-Based Router
    participant PostTopK as Post-TopK Pipeline
    participant Kernel as Routing Kernel

    Launcher->>Runner: Call with routing config<br/>(norm_topk_prob, dtypeLogits)
    activate Runner
    Runner->>RoutingDispatch: Route based on method<br/>(Default, SigmoidRenorm, etc.)
    activate RoutingDispatch
    
    alt Has TopK Results & No Scores
        RoutingDispatch->>PostTopK: Delegate to shared pipeline
        activate PostTopK
        PostTopK->>Kernel: Launch post-topK kernel<br/>(histogram/offsets)
    else Has Scores
        RoutingDispatch->>PolicyRouter: Select preprocess/postprocess<br/>policies based on method
        activate PolicyRouter
        PolicyRouter->>Kernel: LAUNCH_ROUTING_FOR_POLICY<br/>with composed policies
        deactivate PolicyRouter
    end
    
    Kernel-->>PostTopK: Complete post-topK
    deactivate PostTopK
    Kernel-->>RoutingDispatch: Routing complete
    deactivate RoutingDispatch
    RoutingDispatch-->>Runner: Return routing results
    deactivate Runner
    Runner-->>Launcher: Provide expert indices<br/>& histogram

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested labels

run-ci, op: moe, op: moe-routing

Suggested reviewers

  • aleozlx
  • cyx-6
  • djmmoss
  • yzh119
  • joker-eph
  • jiahanc

Poem

🐰 Behold! The routing pathways bloom anew,
With policies composed and kernels refactored too,
Top-K normalization hops in with graceful stride,
While post-pipelines merge custom streams side by side!
SM90 cooperates, and renormalize bids adieu—
A hoppy refactor, through and through! 🎪

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 42.86%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
  • Title check: ❓ Inconclusive. The title 'Refactor the routing part' is vague and generic and does not convey the scope of the changes. Resolution: replace it with a more specific title such as 'Refactor MoE routing with policy-based expert selection and unified post-TopK pipeline'.
✅ Passed checks (1 passed)
  • Description check: ✅ Passed. The description provides detailed technical documentation of the refactoring, including execution flow, policy design, tier configuration, and available policies.


@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the routing component of the trtllm-gen, focusing on improving the flexibility, extensibility, and performance of MoE (Mixture of Experts) configurations. It introduces a policy-based design for expert selection, adds support for new routing methods, and enhances data type flexibility. The changes aim to provide a more efficient and adaptable routing mechanism for various models.

Highlights

  • Routing Logic Refactor: The routing logic has been refactored to enhance extensibility and performance, introducing a policy-based design for expert selection.
  • New Routing Methods: Added support for new routing methods, including Default and SigmoidRenorm, expanding the flexibility of MoE configurations.
  • Data Type Flexibility: Improved data type flexibility by allowing routing logits and bias to accept both float32 and bfloat16, enhancing compatibility with various models.
  • Performance Improvements: Optimized performance through cooperative kernel launches and multi-kernel pipelines, adapting to different token counts for efficient processing.


Changelog
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Added norm_topk_prob as a parameter to control top-K score normalization.
    • Introduced mRoutingLogitsDtype to support different data types for routing logits.
    • Modified init_common to accept norm_topk_prob.
    • Updated kernel launches to pass mRoutingLogitsDtype and norm_topk_prob.
  • csrc/trtllm_fused_moe_routing_common.cu
    • Implemented a shared post-topK pipeline for all routing methods to avoid code duplication.
    • Added runPostTopKPipeline function to handle path selection based on token count.
  • csrc/trtllm_fused_moe_routing_custom.cu
    • Introduced custom routing kernels and policies for flexible expert selection.
    • Implemented block, cluster, and multi-kernel pipelines for different token counts.
    • Added support for various preprocessing and postprocessing policies, such as Softmax, Sigmoid, and SumNormalize.
  • csrc/trtllm_fused_moe_routing_deepseek.cu
    • Modified the routingMainKernel to support different data types for routing logits and bias.
    • Updated kernel launches to pass the forceFloatInput flag.
    • Leveraged the shared post-topK pipeline from routingCustom for improved code reuse.
  • csrc/trtllm_fused_moe_routing_llama4.cu
    • Modified the routingIndicesWarpKernel to handle different data types for routing logits.
    • Leveraged the shared post-topK pipeline from routingCustom for improved code reuse.
  • csrc/trtllm_fused_moe_runner.cu
    • Modified the run function to support different routing methods and data types.
    • Updated kernel launches to pass the dtypeLogits and normTopkProb parameters.
    • Replaced routingRenormalize with routingCustom for enhanced flexibility.
  • flashinfer/fused_moe/core.py
    • Added SigmoidRenorm to the RoutingMethodType enum.
    • Modified forward functions to accept and pass the norm_topk_prob parameter.
    • Updated trtllm_bf16_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_block_scale_moe, trtllm_fp4_block_scale_moe, and trtllm_mxint4_block_scale_moe to include norm_topk_prob in the function signatures.
  • include/flashinfer/trtllm/common/cudaUtils.h
    • Added getSMVersion and getMultiProcessorCount inline functions to query device properties.
  • include/flashinfer/trtllm/fused_moe/DevKernel.h
    • Moved routing-specific launch macros to RoutingDevKernel.h.
    • Updated LAUNCH_ROUTING_FOR_POLICY to use the new template signature with runtime isPow2/UsePdl.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
    • Added mPdlOverlapWithNext to DataBase to control overlap with the next kernel.
    • Modified DataBase to include mDtypeInput.
    • Updated loadScalar to support different data types for routing bias.
    • Updated KernelParamsBase to include MaxNumTopExperts.
    • Updated KernelParamsBase to include mUsePdl and mIsPow2.
    • Moved routing-specific launch macros to RoutingDevKernel.h.
  • include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh
    • Updated the implementation of reduceTopK to use fast redux.sync.max.u32 on SM100+ for 32-bit packed types.
    • Updated the implementation of reduceTopK to support candidate counts of up to 64 × 32 = 2048.
Activity
  • Implemented a policy-based design for expert selection.
  • Added support for new routing methods, including Default and SigmoidRenorm.
  • Enhanced data type flexibility by allowing routing logits and bias to accept both float32 and bfloat16.
  • Optimized performance through cooperative kernel launches and multi-kernel pipelines.
  • Added a shared post-topK pipeline for all routing methods to avoid code duplication.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (1)

523-543: ⚠️ Potential issue | 🔴 Critical

Always decode idx from packed Top-K input.

In the packed-input branch, idx is only assigned when mPtrTopKWeights != nullptr. That makes the histogram use an uninitialized expert id whenever the routed path skips a separate weights buffer, and it also diverges from the coop helper below, which correctly reads packed ids regardless of whether weights are materialized.

🐛 Suggested fix
   auto loopBody = [&](int expandedIdx) {
-    PackedScoreIdx<OutputT> scoreIdx;
     int idx;
     if (params.mPtrTopKIds != nullptr) {
       idx = params.mPtrTopKIds[expandedIdx];
     } else {
-      // If params.mPtrTopKIds != nullptr, we don't need to store the weights
-      if (params.mPtrTopKWeights != nullptr) {
-        scoreIdx = params.mPtrTopKPacked[expandedIdx];
-        idx = scoreIdx.idx;
-        params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
-      }
+      auto scoreIdx = params.mPtrTopKPacked[expandedIdx];
+      idx = scoreIdx.idx;
+      if (params.mPtrTopKWeights != nullptr) {
+        params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
+      }
     }
     // check whether this expert is local to our GPU at all and ignore if not
     auto localExpertIdx = idx - params.mLocalExpertsStartIdx;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh` around lines 523 -
543, The lambda loopBody uses idx uninitialized when mPtrTopKIds==nullptr but
params.mPtrTopKWeights==nullptr; always decode the packed top-k entry instead of
conditioning it on mPtrTopKWeights: when params.mPtrTopKIds is nullptr, read
PackedScoreIdx<OutputT> scoreIdx = params.mPtrTopKPacked[expandedIdx]; set idx =
scoreIdx.idx unconditionally, and then only write
params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score) if
params.mPtrTopKWeights != nullptr; this ensures loopBody, the localExpertIdx
computation and atomicAdd(&smemExpertCount[idx], 1) use a valid decoded idx.
🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_routing_deepseek.cu (1)

546-548: Avoid mutating caller-owned Data pointer fields in-place.

Setting data.mPtrExpertCounts = nullptr mutates persistent input state and can break subsequent calls that reuse the same Data object.

♻️ Suggested refactor
-  if (data.mPtrPermutedIdxSize != nullptr) {
+  if (data.mPtrPermutedIdxSize != nullptr) {
+    Data launchData = data;
     bool const useSingleCluster = data.mNumTokens <= 1024;
     if (!useSingleCluster) {
       FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr,
                        "When `#tokens` is large, `mPtrExpertCounts` is a required input.");
     } else {
-      data.mPtrExpertCounts =
-          nullptr;  // Set it to nullptr for single-cluster code path, as it won't be used
+      launchData.mPtrExpertCounts = nullptr;
     }
@@
-      data.mPdlOverlapWithNext = false;  // Last kernel
-      launchClusterKernel(data, numThreadsHist, stream);
+      launchData.mPdlOverlapWithNext = false;  // Last kernel
+      launchClusterKernel(launchData, numThreadsHist, stream);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 546 - 548, Do not
mutate the caller-owned Data struct in-place by assigning data.mPtrExpertCounts
= nullptr; instead, preserve the original by introducing a local variable (e.g.,
auto* localExpertCounts = data.mPtrExpertCounts) and use/override that local
variable for the single-cluster code path where a nullptr is needed; update all
subsequent uses in this scope to reference localExpertCounts (and ensure any
ownership/freeing logic still respects the original Data ownership contract) so
the passed-in Data remains unchanged for callers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fused_moe_routing_custom.cu`:
- Around line 601-612: The coop eligibility check allows data.mNumExperts <=
1024 but the coop kernel (launched by launchCoopKernel) only supports up to
NumExperts576Experts (576); update the canUseCoop condition to use the actual
supported limit (e.g., replace the 1024 check with 576/NumExperts576Experts)
and/or add an explicit runtime guard just before calling launchCoopKernel that
verifies data.mNumExperts <= NumExperts576Experts to prevent attempting to
launch an unsupported coop kernel.
- Around line 445-479: The coop-launch branch in launchCoopKernel silently skips
launching any kernel when data.mNumExperts > NumExperts576Experts (only logging
FLASHINFER_WARN), leaving routing uninitialized; fix by either failing fast
(change the else to return an error/throw an exception or call FLASHINFER_ERROR)
or implement a fallback that launches a non-coop/multi-kernel routing path
(e.g., call LAUNCH_ROUTING_WITH_POLICIES with /*coopLaunch=*/false and the
non-coop kernel symbol such as routingIndicesKernel or the existing multi-kernel
dispatcher) so routing is always produced; update the else block in
launchCoopKernel accordingly and ensure any callers handle the error/exception
if you choose the fail-fast option.

In `@flashinfer/fused_moe/core.py`:
- Line 1347: The fake-op function signatures (_fake_trtllm_bf16_moe,
_fake_trtllm_fp8_per_tensor_scale_moe, _fake_trtllm_fp8_block_scale_moe,
_fake_trtllm_fp4_block_scale_moe, _fake_trtllm_mxint4_block_scale_moe) must be
updated so their parameter lists exactly match their real-op counterparts: add
the trailing norm_topk_prob boolean parameter and any other missing parameters
present in the real op schemas; adjust the function definitions and any internal
arg handling to accept and forward these parameters (preserve default values),
ensuring the fake-op signatures are identical to the real ops to avoid schema
drift during torch.compile/meta-tensor dispatch.

In `@include/flashinfer/trtllm/common/cudaUtils.h`:
- Around line 276-291: Both getSMVersion() and getMultiProcessorCount() silently
ignore CUDA errors; wrap every CUDA call inside these functions with
FLASHINFER_CHECK() so failures are propagated instead of returning bogus values.
Specifically, in getSMVersion() call FLASHINFER_CHECK(cudaGetDevice(&device))
and FLASHINFER_CHECK(cudaDeviceGetAttribute(&sm_major,
cudaDevAttrComputeCapabilityMajor, device)) /
FLASHINFER_CHECK(cudaDeviceGetAttribute(&sm_minor,
cudaDevAttrComputeCapabilityMinor, device)); in getMultiProcessorCount() call
FLASHINFER_CHECK(cudaGetDevice(&device)) and
FLASHINFER_CHECK(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount,
device)) so any CUDA error surfaces via the existing FLASHINFER_CHECK handling.

In `@include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh`:
- Around line 180-203: SumNormalizePostprocess::apply can produce a zero
reduction when reused for SigmoidRenorm, causing NaNs on division; fix it by
clamping the computed denominator before dividing: after computing sum (the
cg::reduce result) ensure sum = max(sum, epsilon) with a small float epsilon
(e.g., 1e-6f) so warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum is safe;
update the apply method in SumNormalizePostprocess (and the equivalent block
around the other occurrence noted) to perform this clamp when
params.normTopkProb is true or whenever the reduction result may be zero.

In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.h`:
- Around line 327-343: The Data struct sets mDtypeOutput to Fp32 while the
comment and other routing structs (routingDeepSeek::Data, routingLlama4::Data)
expect Bfloat16; change the default of Data::mDtypeOutput from tg::Dtype::Fp32
to tg::Dtype::Bfloat16 (or explicitly document why Fp32 is required) so defaults
are consistent—update the initializer of mDtypeOutput in the Data struct and
adjust any related documentation/comments to match.

---

Outside diff comments:
In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh`:
- Around line 523-543: The lambda loopBody uses idx uninitialized when
mPtrTopKIds==nullptr but params.mPtrTopKWeights==nullptr; always decode the
packed top-k entry instead of conditioning it on mPtrTopKWeights: when
params.mPtrTopKIds is nullptr, read PackedScoreIdx<OutputT> scoreIdx =
params.mPtrTopKPacked[expandedIdx]; set idx = scoreIdx.idx unconditionally, and
then only write params.mPtrTopKWeights[expandedIdx] =
static_cast<OutputT>(scoreIdx.score) if params.mPtrTopKWeights != nullptr; this
ensures loopBody, the localExpertIdx computation and
atomicAdd(&smemExpertCount[idx], 1) use a valid decoded idx.

---

Nitpick comments:
In `@csrc/trtllm_fused_moe_routing_deepseek.cu`:
- Around line 546-548: Do not mutate the caller-owned Data struct in-place by
assigning data.mPtrExpertCounts = nullptr; instead, preserve the original by
introducing a local variable (e.g., auto* localExpertCounts =
data.mPtrExpertCounts) and use/override that local variable for the
single-cluster code path where a nullptr is needed; update all subsequent uses
in this scope to reference localExpertCounts (and ensure any ownership/freeing
logic still respects the original Data ownership contract) so the passed-in Data
remains unchanged for callers.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9d2337f3-d0f7-4f7c-90b9-208b0ea8826a

📥 Commits

Reviewing files that changed from the base of the PR and between e4dc66f and 97b6f0e.

📒 Files selected for processing (18)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_routing_common.cu
  • csrc/trtllm_fused_moe_routing_custom.cu
  • csrc/trtllm_fused_moe_routing_deepseek.cu
  • csrc/trtllm_fused_moe_routing_llama4.cu
  • csrc/trtllm_fused_moe_routing_renormalize.cu
  • csrc/trtllm_fused_moe_runner.cu
  • flashinfer/fused_moe/core.py
  • flashinfer/jit/fused_moe.py
  • include/flashinfer/trtllm/common/cudaUtils.h
  • include/flashinfer/trtllm/fused_moe/DevKernel.h
  • include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh
  • include/flashinfer/trtllm/fused_moe/RoutingDevKernel.h
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh
  • include/flashinfer/trtllm/fused_moe/runner.h
  • tests/moe/test_trtllm_gen_fused_moe.py
💤 Files with no reviewable changes (1)
  • csrc/trtllm_fused_moe_routing_renormalize.cu

Comment on lines +445 to +479
void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream) {
  if (data.mNumExperts <= NumExperts128Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts128Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts160Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts160Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts256Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts256Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts384Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts384Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts512Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts512Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts576Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts576Experts,
                                 NumTop8Experts);
  } else {
    FLASHINFER_WARN("Coop kernel does not support numExperts > %d", NumExperts576Experts);
  }
}

⚠️ Potential issue | 🟠 Major

Silent failure when numExperts > 576.

When numExperts exceeds NumExperts576Experts, the function only prints a warning but does not launch any kernel, leaving the routing incomplete. This could cause downstream kernels to read uninitialized data.

Consider either:

  1. Returning an error/throwing to fail fast, or
  2. Falling back to a multi-kernel path that can handle larger expert counts
🐛 Suggested fix
   } else if (data.mNumExperts <= NumExperts576Experts) {
     LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                  numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                  NoOpPreprocess, NoOpPostprocess, NumExperts576Experts,
                                  NumTop8Experts);
   } else {
-    FLASHINFER_WARN("Coop kernel does not support numExperts > %d", NumExperts576Experts);
+    TVM_FFI_LOG_AND_THROW(NotImplementedError)
+        << "Coop kernel does not support numExperts > " << NumExperts576Experts
+        << ", got " << data.mNumExperts;
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_custom.cu` around lines 445 - 479, The
coop-launch branch in launchCoopKernel silently skips launching any kernel when
data.mNumExperts > NumExperts576Experts (only logging FLASHINFER_WARN), leaving
routing uninitialized; fix by either failing fast (change the else to return an
error/throw an exception or call FLASHINFER_ERROR) or implement a fallback that
launches a non-coop/multi-kernel routing path (e.g., call
LAUNCH_ROUTING_WITH_POLICIES with /*coopLaunch=*/false and the non-coop kernel
symbol such as routingIndicesKernel or the existing multi-kernel dispatcher) so
routing is always produced; update the else block in launchCoopKernel
accordingly and ensure any callers handle the error/exception if you choose the
fail-fast option.

Comment on lines +601 to +612
    static int const smMajor = tensorrt_llm::common::getSMVersion() / 10;
    bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) &&
                            (data.mPtrPermutedIdxSize != nullptr);
    bool useCoop = false;
    int numBlocksCoop = 0;

    if (canUseCoop) {
      static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
      numBlocksCoop = smCount - 8;  // Reserve 8 SMs for overlapping kernels
      int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
      useCoop = (data.mNumTokens <= maxTokensCoop);
    }

⚠️ Potential issue | 🟠 Major

Mismatch between coop eligibility check and actual kernel support.

The canUseCoop check allows numExperts <= 1024, but launchCoopKernel only supports up to NumExperts576Experts (576). If numExperts is between 577 and 1024, useCoop could be true but the coop kernel would silently fail.

🐛 Suggested fix
     if (canUseCoop) {
       static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
       numBlocksCoop = smCount - 8;  // Reserve 8 SMs for overlapping kernels
       int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
-      useCoop = (data.mNumTokens <= maxTokensCoop);
+      useCoop = (data.mNumTokens <= maxTokensCoop) && (data.mNumExperts <= NumExperts576Experts);
     }

Alternatively, update line 602 to use the actual limit:

-    bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) &&
+    bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= NumExperts576Experts) &&
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_custom.cu` around lines 601 - 612, The coop
eligibility check allows data.mNumExperts <= 1024 but the coop kernel (launched
by launchCoopKernel) only supports up to NumExperts576Experts (576); update the
canUseCoop condition to use the actual supported limit (e.g., replace the 1024
check with 576/NumExperts576Experts) and/or add an explicit runtime guard just
before calling launchCoopKernel that verifies data.mNumExperts <=
NumExperts576Experts to prevent attempting to launch an unsupported coop kernel.

    do_finalize: bool = True,
    enable_pdl: Optional[bool] = None,
    tune_max_num_tokens: int = 8192,
    norm_topk_prob: bool = True,

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
import ast
from pathlib import Path

path = Path("flashinfer/fused_moe/core.py")
module = ast.parse(path.read_text())
defs = {node.name: node for node in ast.walk(module) if isinstance(node, ast.FunctionDef)}

pairs = [
    ("trtllm_bf16_moe_op", "_fake_trtllm_bf16_moe"),
    ("trtllm_fp8_per_tensor_scale_moe_op", "_fake_trtllm_fp8_per_tensor_scale_moe"),
    ("trtllm_fp8_block_scale_moe_op", "_fake_trtllm_fp8_block_scale_moe"),
    ("trtllm_fp4_block_scale_moe_op", "_fake_trtllm_fp4_block_scale_moe"),
    ("trtllm_mxint4_block_scale_moe_op", "_fake_trtllm_mxint4_block_scale_moe"),
]

for real, fake in pairs:
    rp = [arg.arg for arg in defs[real].args.args]
    fp = [arg.arg for arg in defs[fake].args.args]
    print(f"{real} vs {fake}: {'OK' if rp == fp else 'MISMATCH'}")
    if rp != fp:
        print("  real:", rp)
        print("  fake:", fp)
PY

Repository: flashinfer-ai/flashinfer

Length of output: 5080


Add norm_topk_prob (and missing parameters) to all fake-op signatures to match their real op counterparts.

The real custom ops now accept norm_topk_prob as a trailing parameter, but the corresponding fake ops (_fake_trtllm_bf16_moe, _fake_trtllm_fp8_per_tensor_scale_moe, _fake_trtllm_fp8_block_scale_moe, _fake_trtllm_fp4_block_scale_moe, _fake_trtllm_mxint4_block_scale_moe) retain their old signatures. Fake-op schemas must exactly match their real ops for torch.compile to correctly dispatch through meta tensors. Schema drift will cause failures when callers pass the new argument.

Also applies to: 1512-1512, 1671-1671, 1883-1883, 2106-2106

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` at line 1347, The fake-op function signatures
(_fake_trtllm_bf16_moe, _fake_trtllm_fp8_per_tensor_scale_moe,
_fake_trtllm_fp8_block_scale_moe, _fake_trtllm_fp4_block_scale_moe,
_fake_trtllm_mxint4_block_scale_moe) must be updated so their parameter lists
exactly match their real-op counterparts: add the trailing norm_topk_prob
boolean parameter and any other missing parameters present in the real op
schemas; adjust the function definitions and any internal arg handling to accept
and forward these parameters (preserve default values), ensuring the fake-op
signatures are identical to the real ops to avoid schema drift during
torch.compile/meta-tensor dispatch.

Comment on lines +276 to +291
inline int getSMVersion() {
int device{-1};
cudaGetDevice(&device);
int sm_major = 0;
int sm_minor = 0;
cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
return sm_major * 10 + sm_minor;
}

inline int getMultiProcessorCount() {
int device{-1};
cudaGetDevice(&device);
int count = 0;
cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device);
return count;

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Locate the cudaUtils.h file
find . -name "cudaUtils.h" -type f

Repository: flashinfer-ai/flashinfer

Length of output: 173


🏁 Script executed:

# Check file contents around lines 276-291
if [ -f "include/flashinfer/trtllm/common/cudaUtils.h" ]; then
  sed -n '276,295p' include/flashinfer/trtllm/common/cudaUtils.h
fi

Repository: flashinfer-ai/flashinfer

Length of output: 606


🏁 Script executed:

# Search for FLASHINFER_CHECK macro definition in the repository
rg "define FLASHINFER_CHECK|FLASHINFER_CHECK\(" --type c --type cpp -A 3 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 3430


🏁 Script executed:

# Check what headers are included at the top of cudaUtils.h
if [ -f "include/flashinfer/trtllm/common/cudaUtils.h" ]; then
  head -50 include/flashinfer/trtllm/common/cudaUtils.h
fi

Repository: flashinfer-ai/flashinfer

Length of output: 1634


🏁 Script executed:

# Search for similar CUDA helper functions in flashinfer to see how they handle errors
rg "cudaGetDevice|cudaDeviceGetAttribute" --type c --type cpp -B 2 -A 2 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 5753


Surface CUDA query failures instead of silently returning bogus device info.

The functions getSMVersion() and getMultiProcessorCount() ignore CUDA status codes from cudaGetDevice() and cudaDeviceGetAttribute(). If the thread lacks an active device context, these helpers silently degrade to device = -1 or count = 0, causing any downstream dispatch logic to take incorrect paths. This pattern is inconsistent with error-checked versions of the same functions in include/flashinfer/trtllm/common.h (uses CUDACHECK) and csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h (uses check_cuda_error()). The FLASHINFER_CHECK macro is already available in this file and used elsewhere (e.g., in cudaStreamIsCapturing()).

Wrap all CUDA calls with FLASHINFER_CHECK() to match the error-handling pattern established in the rest of the codebase.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/common/cudaUtils.h` around lines 276 - 291, Both
getSMVersion() and getMultiProcessorCount() silently ignore CUDA errors; wrap
every CUDA call inside these functions with FLASHINFER_CHECK() so failures are
propagated instead of returning bogus values. Specifically, in getSMVersion()
call FLASHINFER_CHECK(cudaGetDevice(&device)) and
FLASHINFER_CHECK(cudaDeviceGetAttribute(&sm_major,
cudaDevAttrComputeCapabilityMajor, device)) /
FLASHINFER_CHECK(cudaDeviceGetAttribute(&sm_minor,
cudaDevAttrComputeCapabilityMinor, device)); in getMultiProcessorCount() call
FLASHINFER_CHECK(cudaGetDevice(&device)) and
FLASHINFER_CHECK(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount,
device)) so any CUDA error surfaces via the existing FLASHINFER_CHECK handling.

Comment on lines +180 to +203
struct SumNormalizePostprocess {
  template <typename OutputT>
  struct Params {
    bool normTopkProb = true;

    void set(routingCustom::Data const& data) { normTopkProb = data.mNormTopkProb; }
  };

  template <typename DataType, int K, typename ParamsT>
  __forceinline__ __device__ static void apply(cg::thread_block_tile<WarpSize> const& warp,
                                               DataType (&warpTopKScore)[K],
                                               int32_t const (&/*warpTopKExpertIdx*/)[K],
                                               int32_t laneIdx, int32_t topK,
                                               ParamsT const& params) {
    float sum = float{1.f};
    if (params.normTopkProb) {
      sum = static_cast<float>(laneIdx < topK ? warpTopKScore[laneIdx] : 0);
      sum = cg::reduce(warp, sum, cg::plus<float>());
    }
    if (laneIdx < topK) {
      warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum;
    }
  }
};

⚠️ Potential issue | 🟠 Major

Clamp the renormalization denominator for SigmoidRenorm.

SumNormalizePostprocess used to only see softmax outputs, but this PR now reuses it for SigmoidRenorm. With very negative or -inf logits, every selected sigmoid weight can be exactly 0, so the reduction at Line 197 returns 0 and Lines 199-200 produce NaNs.

🛠️ Suggested fix
     if (params.normTopkProb) {
       sum = static_cast<float>(laneIdx < topK ? warpTopKScore[laneIdx] : 0);
       sum = cg::reduce(warp, sum, cg::plus<float>());
     }
     if (laneIdx < topK) {
-      warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum;
+      float denom = params.normTopkProb ? fmaxf(sum, 1e-20f) : 1.f;
+      warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / denom;
     }
   }
 };

Also applies to: 456-462

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh` around lines 180
- 203, SumNormalizePostprocess::apply can produce a zero reduction when reused
for SigmoidRenorm, causing NaNs on division; fix it by clamping the computed
denominator before dividing: after computing sum (the cg::reduce result) ensure
sum = max(sum, epsilon) with a small float epsilon (e.g., 1e-6f) so
warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum is safe; update the apply
method in SumNormalizePostprocess (and the equivalent block around the other
occurrence noted) to perform this clamp when params.normTopkProb is true or
whenever the reduction result may be zero.

Comment on lines 327 to +343
struct Data : public DataBase {
  tg::Dtype mDtypeExpW{tg::Dtype::Fp32};
  tg::Dtype mDtypeElt{tg::Dtype::Bfloat16};
  tg::Dtype mDtypeOutput{tg::Dtype::Fp32};     // OutputT: expert weights dtype (typically Bfloat16)
  tg::Dtype mDtypeInput{tg::Dtype::Bfloat16};  // InputT: routing logits dtype (Bfloat16 or Fp32)

  bool mDoSoftmaxBeforeTopK{false};
  RoutingPreprocessType mPreprocessType{RoutingPreprocessType::None};
  RoutingPostprocessType mPostprocessType{RoutingPostprocessType::Softmax};
  bool mNormTopkProb{true};  // Default value is true for Qwen3 model
  bool mApplySoftmaxAfterTopK{false};

  // Optional: per-expert routing bias (used by SigmoidBias preprocess).
  void const* mPtrRoutingBias{nullptr};
  // Dtype of the routing bias buffer (Bfloat16 or Fp32). Used to read mPtrRoutingBias correctly.
  tg::Dtype mDtypeBias{tg::Dtype::Bfloat16};
  // Optional: scaling factor applied to final scores (used by ScaledSumNormalize postprocess).
  float mRouteScale{1.0f};
  // Optional: epsilon added to the sum before division to prevent division by zero.
  // MiniMax2 uses 1e-20f; DeepSeek uses 0.0f (no epsilon).
  float mSumEpsilon{0.0f};

⚠️ Potential issue | 🟡 Minor

Minor: Default dtype inconsistency with comment.

mDtypeOutput defaults to Fp32, but the comment says "OutputT: expert weights dtype (typically Bfloat16)". Other routing Data structs (routingDeepSeek::Data, routingLlama4::Data) default to Bfloat16.

Is Fp32 intentional for custom routing, or should this be Bfloat16 for consistency?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.h` around lines 327 - 343,
The Data struct sets mDtypeOutput to Fp32 while the comment and other routing
structs (routingDeepSeek::Data, routingLlama4::Data) expect Bfloat16; change the
default of Data::mDtypeOutput from tg::Dtype::Fp32 to tg::Dtype::Bfloat16 (or
explicitly document why Fp32 is required) so defaults are consistent—update the
initializer of mDtypeOutput in the Data struct and adjust any related
documentation/comments to match.

@gemini-code-assist bot left a comment
Code Review

This pull request refactors the MoE routing logic by introducing a policy-based design, which enhances modularity and flexibility. The changes consolidate duplicated code into shared components like runPostTopKPipeline and add support for more flexible data types for routing logits and biases. Several bug fixes related to expert index calculation and stream synchronization are also included. My review focuses on the new code structure and I have a couple of suggestions to improve readability.

// Number of blocks we can use in the cooperative kernel
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
// WAR: Reserve 8 SMs for overlapping kernels.
numBlocksCoop = smCount - 8;

medium

The magic number 8 should be defined as a named constant to improve readability and maintainability. This would make it clearer why 8 SMs are being reserved. This pattern is also repeated in csrc/trtllm_fused_moe_routing_custom.cu and csrc/trtllm_fused_moe_routing_deepseek.cu.

For example, you could add static constexpr int kReservedSMsForOverlapping = 8; before this line and use the constant here.

Comment on lines +119 to +120
int mask =
idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF;

medium

The ternary expression to create the mask is a bit verbose and hard to read. A bitwise operation would be more concise and clearer in its intent.

    int mask = ~(0xFF << (idx * 8));

@wenscarl (Collaborator)

Verified that this PR fixes #2732. Cc. @aleozlx for viz #2714
