Conversation
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
📝 Walkthrough

This PR introduces a policy-driven routing framework for MoE implementations, adding support for normalizing top-K probabilities and tracking the routing-logits dtype dynamically. It refactors the routing kernel architecture across the custom, DeepSeek, and Llama4 paths, consolidates the post-topK pipeline logic, and updates launcher APIs to propagate the new parameters through the initialization and execution flows.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Launcher as MoE Launcher
    participant Runner as Runner::run()
    participant RoutingDispatch as Routing Dispatch
    participant PolicyRouter as Policy-Based Router
    participant PostTopK as Post-TopK Pipeline
    participant Kernel as Routing Kernel
    Launcher->>Runner: Call with routing config<br/>(norm_topk_prob, dtypeLogits)
    activate Runner
    Runner->>RoutingDispatch: Route based on method<br/>(Default, SigmoidRenorm, etc.)
    activate RoutingDispatch
    alt Has TopK Results & No Scores
        RoutingDispatch->>PostTopK: Delegate to shared pipeline
        activate PostTopK
        PostTopK->>Kernel: Launch post-topK kernel<br/>(histogram/offsets)
    else Has Scores
        RoutingDispatch->>PolicyRouter: Select preprocess/postprocess<br/>policies based on method
        activate PolicyRouter
        PolicyRouter->>Kernel: LAUNCH_ROUTING_FOR_POLICY<br/>with composed policies
        deactivate PolicyRouter
    end
    Kernel-->>PostTopK: Complete post-topK
    deactivate PostTopK
    Kernel-->>RoutingDispatch: Routing complete
    deactivate RoutingDispatch
    RoutingDispatch-->>Runner: Return routing results
    deactivate Runner
    Runner-->>Launcher: Provide expert indices<br/>& histogram
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 flagged
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refactors the routing component of trtllm-gen, focusing on improving the flexibility, extensibility, and performance of MoE (Mixture of Experts) configurations. It introduces a policy-based design for expert selection, adds support for new routing methods, and enhances data-type flexibility. The changes aim to provide a more efficient and adaptable routing mechanism for various models.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (1)
523-543: ⚠️ Potential issue | 🔴 Critical

Always decode `idx` from the packed Top-K input.

In the packed-input branch, `idx` is only assigned when `mPtrTopKWeights != nullptr`. That makes the histogram use an uninitialized expert id whenever the routed path skips a separate weights buffer, and it also diverges from the coop helper below, which correctly reads packed ids regardless of whether weights are materialized.

🐛 Suggested fix
```diff
   auto loopBody = [&](int expandedIdx) {
-    PackedScoreIdx<OutputT> scoreIdx;
     int idx;
     if (params.mPtrTopKIds != nullptr) {
       idx = params.mPtrTopKIds[expandedIdx];
     } else {
-      // If params.mPtrTopKIds != nullptr, we don't need to store the weights
-      if (params.mPtrTopKWeights != nullptr) {
-        scoreIdx = params.mPtrTopKPacked[expandedIdx];
-        idx = scoreIdx.idx;
-        params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
-      }
+      auto scoreIdx = params.mPtrTopKPacked[expandedIdx];
+      idx = scoreIdx.idx;
+      if (params.mPtrTopKWeights != nullptr) {
+        params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
+      }
     }
     // check whether this expert is local to our GPU at all and ignore if not
     auto localExpertIdx = idx - params.mLocalExpertsStartIdx;
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh` around lines 523 - 543, The lambda loopBody uses idx uninitialized when mPtrTopKIds==nullptr but params.mPtrTopKWeights==nullptr; always decode the packed top-k entry instead of conditioning it on mPtrTopKWeights: when params.mPtrTopKIds is nullptr, read PackedScoreIdx<OutputT> scoreIdx = params.mPtrTopKPacked[expandedIdx]; set idx = scoreIdx.idx unconditionally, and then only write params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score) if params.mPtrTopKWeights != nullptr; this ensures loopBody, the localExpertIdx computation and atomicAdd(&smemExpertCount[idx], 1) use a valid decoded idx.
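To make the failure mode concrete, here is a host-side Python model of the two branches. It is purely illustrative: `PackedScoreIdx`, the `-1` sentinel, and the list-based histogram stand in for the CUDA types and for a register whose value is simply undefined in the real kernel.

```python
from dataclasses import dataclass


@dataclass
class PackedScoreIdx:
    score: float
    idx: int


def histogram_buggy(packed, weights_out, num_experts):
    # Mirrors the reported branch: idx is only decoded when a weights
    # buffer exists, so entries are silently dropped (or worse) otherwise.
    counts = [0] * num_experts
    for entry in packed:
        idx = -1  # stand-in for "uninitialized" in the CUDA kernel
        if weights_out is not None:
            idx = entry.idx
            weights_out.append(entry.score)
        if 0 <= idx < num_experts:  # garbage idx -> wrong/missing count
            counts[idx] += 1
    return counts


def histogram_fixed(packed, weights_out, num_experts):
    # Always decode idx from the packed entry; store weights only if asked.
    counts = [0] * num_experts
    for entry in packed:
        idx = entry.idx
        if weights_out is not None:
            weights_out.append(entry.score)
        counts[idx] += 1
    return counts


packed = [PackedScoreIdx(0.7, 2), PackedScoreIdx(0.3, 5)]
print(histogram_buggy(packed, None, 8))  # histogram stays empty
print(histogram_fixed(packed, None, 8))  # experts 2 and 5 each counted once
```

With no weights buffer, the buggy variant never decodes an expert id, so the histogram stays empty; the fixed variant counts correctly in both modes.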
🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_routing_deepseek.cu (1)
546-548: Avoid mutating caller-owned `Data` pointer fields in-place.

Setting `data.mPtrExpertCounts = nullptr` mutates persistent input state and can break subsequent calls that reuse the same `Data` object.

♻️ Suggested refactor
```diff
-  if (data.mPtrPermutedIdxSize != nullptr) {
+  if (data.mPtrPermutedIdxSize != nullptr) {
+    Data launchData = data;
     bool const useSingleCluster = data.mNumTokens <= 1024;
     if (!useSingleCluster) {
       FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr,
                        "When `#tokens` is large, `mPtrExpertCounts` is a required input.");
     } else {
-      data.mPtrExpertCounts =
-          nullptr;  // Set it to nullptr for single-cluster code path, as it won't be used
+      launchData.mPtrExpertCounts = nullptr;
     }
@@
-    data.mPdlOverlapWithNext = false;  // Last kernel
-    launchClusterKernel(data, numThreadsHist, stream);
+    launchData.mPdlOverlapWithNext = false;  // Last kernel
+    launchClusterKernel(launchData, numThreadsHist, stream);
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 546 - 548, Do not mutate the caller-owned Data struct in-place by assigning data.mPtrExpertCounts = nullptr; instead, preserve the original by introducing a local variable (e.g., auto* localExpertCounts = data.mPtrExpertCounts) and use/override that local variable for the single-cluster code path where a nullptr is needed; update all subsequent uses in this scope to reference localExpertCounts (and ensure any ownership/freeing logic still respects the original Data ownership contract) so the passed-in Data remains unchanged for callers.
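The nitpick boils down to a copy-before-mutate pattern. A minimal Python sketch, with a hypothetical `Data` stand-in for the C++ struct:

```python
from dataclasses import dataclass, replace


@dataclass
class Data:
    num_tokens: int
    expert_counts: object  # stands in for the mPtrExpertCounts pointer


def launch(data: Data) -> Data:
    # Copy first, then mutate only the copy: the caller-owned Data
    # survives intact for reuse on subsequent calls.
    launch_data = replace(data)
    if data.num_tokens <= 1024:  # single-cluster path does not use the counts
        launch_data.expert_counts = None
    return launch_data


d = Data(num_tokens=16, expert_counts="buffer")
launched = launch(d)
assert launched.expert_counts is None  # the copy used for the launch
assert d.expert_counts == "buffer"     # caller's object untouched
```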
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9d2337f3-d0f7-4f7c-90b9-208b0ea8826a
📒 Files selected for processing (18)
- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_routing_common.cu
- csrc/trtllm_fused_moe_routing_custom.cu
- csrc/trtllm_fused_moe_routing_deepseek.cu
- csrc/trtllm_fused_moe_routing_llama4.cu
- csrc/trtllm_fused_moe_routing_renormalize.cu
- csrc/trtllm_fused_moe_runner.cu
- flashinfer/fused_moe/core.py
- flashinfer/jit/fused_moe.py
- include/flashinfer/trtllm/common/cudaUtils.h
- include/flashinfer/trtllm/fused_moe/DevKernel.h
- include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh
- include/flashinfer/trtllm/fused_moe/RoutingDevKernel.h
- include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh
- include/flashinfer/trtllm/fused_moe/runner.h
- tests/moe/test_trtllm_gen_fused_moe.py
💤 Files with no reviewable changes (1)
- csrc/trtllm_fused_moe_routing_renormalize.cu
```cpp
void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream) {
  if (data.mNumExperts <= NumExperts128Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts128Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts160Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts160Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts256Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts256Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts384Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts384Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts512Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts512Experts,
                                 NumTop8Experts);
  } else if (data.mNumExperts <= NumExperts576Experts) {
    LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                 numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                 NoOpPreprocess, NoOpPostprocess, NumExperts576Experts,
                                 NumTop8Experts);
  } else {
    FLASHINFER_WARN("Coop kernel does not support numExperts > %d", NumExperts576Experts);
  }
}
```
Silent failure when numExperts > 576.
When `numExperts` exceeds `NumExperts576Experts`, the function only prints a warning but does not launch any kernel, leaving the routing incomplete. This could cause downstream kernels to read uninitialized data.
Consider either:
- Returning an error/throwing to fail fast, or
- Falling back to a multi-kernel path that can handle larger expert counts
🐛 Suggested fix

```diff
   } else if (data.mNumExperts <= NumExperts576Experts) {
     LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                  numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream,
                                  NoOpPreprocess, NoOpPostprocess, NumExperts576Experts,
                                  NumTop8Experts);
   } else {
-    FLASHINFER_WARN("Coop kernel does not support numExperts > %d", NumExperts576Experts);
+    TVM_FFI_LOG_AND_THROW(NotImplementedError)
+        << "Coop kernel does not support numExperts > " << NumExperts576Experts
+        << ", got " << data.mNumExperts;
   }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/trtllm_fused_moe_routing_custom.cu` around lines 445 - 479, The
coop-launch branch in launchCoopKernel silently skips launching any kernel when
data.mNumExperts > NumExperts576Experts (only logging FLASHINFER_WARN), leaving
routing uninitialized; fix by either failing fast (change the else to return an
error/throw an exception or call FLASHINFER_ERROR) or implement a fallback that
launches a non-coop/multi-kernel routing path (e.g., call
LAUNCH_ROUTING_WITH_POLICIES with /*coopLaunch=*/false and the non-coop kernel
symbol such as routingIndicesKernel or the existing multi-kernel dispatcher) so
routing is always produced; update the else block in launchCoopKernel
accordingly and ensure any callers handle the error/exception if you choose the
fail-fast option.
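The fail-fast option can be sketched as a plain dispatch table. The bucket values mirror the `NumExperts*` constants quoted in the review; everything else is hypothetical:

```python
COOP_EXPERT_BUCKETS = [128, 160, 256, 384, 512, 576]


def pick_coop_bucket(num_experts: int) -> int:
    """Return the smallest supported bucket, or raise instead of silently
    skipping the launch (which would leave routing buffers uninitialized)."""
    for limit in COOP_EXPERT_BUCKETS:
        if num_experts <= limit:
            return limit
    raise NotImplementedError(
        f"Coop kernel does not support numExperts > {COOP_EXPERT_BUCKETS[-1]}, "
        f"got {num_experts}")


assert pick_coop_bucket(130) == 160  # rounds up to the next bucket
```

Raising makes the unsupported range impossible to reach quietly; callers either handle the error or never get partially routed results.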
```cpp
  static int const smMajor = tensorrt_llm::common::getSMVersion() / 10;
  bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) &&
                          (data.mPtrPermutedIdxSize != nullptr);
  bool useCoop = false;
  int numBlocksCoop = 0;

  if (canUseCoop) {
    static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
    numBlocksCoop = smCount - 8;  // Reserve 8 SMs for overlapping kernels
    int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
    useCoop = (data.mNumTokens <= maxTokensCoop);
  }
```
Mismatch between coop eligibility check and actual kernel support.
The `canUseCoop` check allows `numExperts <= 1024`, but `launchCoopKernel` only supports up to `NumExperts576Experts` (576). If `numExperts` is between 577 and 1024, `useCoop` could be true but the coop kernel would silently fail.
🐛 Suggested fix

```diff
   if (canUseCoop) {
     static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
     numBlocksCoop = smCount - 8;  // Reserve 8 SMs for overlapping kernels
     int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
-    useCoop = (data.mNumTokens <= maxTokensCoop);
+    useCoop = (data.mNumTokens <= maxTokensCoop) && (data.mNumExperts <= NumExperts576Experts);
   }
```

Alternatively, update line 602 to use the actual limit:

```diff
-  bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) &&
+  bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= NumExperts576Experts) &&
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
  static int const smMajor = tensorrt_llm::common::getSMVersion() / 10;
  bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) &&
                          (data.mPtrPermutedIdxSize != nullptr);
  bool useCoop = false;
  int numBlocksCoop = 0;
  if (canUseCoop) {
    static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
    numBlocksCoop = smCount - 8;  // Reserve 8 SMs for overlapping kernels
    int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
    useCoop = (data.mNumTokens <= maxTokensCoop) && (data.mNumExperts <= NumExperts576Experts);
  }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/trtllm_fused_moe_routing_custom.cu` around lines 601 - 612, The coop
eligibility check allows data.mNumExperts <= 1024 but the coop kernel (launched
by launchCoopKernel) only supports up to NumExperts576Experts (576); update the
canUseCoop condition to use the actual supported limit (e.g., replace the 1024
check with 576/NumExperts576Experts) and/or add an explicit runtime guard just
before calling launchCoopKernel that verifies data.mNumExperts <=
NumExperts576Experts to prevent attempting to launch an unsupported coop kernel.
```python
    do_finalize: bool = True,
    enable_pdl: Optional[bool] = None,
    tune_max_num_tokens: int = 8192,
    norm_topk_prob: bool = True,
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
python - <<'PY'
import ast
from pathlib import Path
path = Path("flashinfer/fused_moe/core.py")
module = ast.parse(path.read_text())
defs = {node.name: node for node in ast.walk(module) if isinstance(node, ast.FunctionDef)}
pairs = [
    ("trtllm_bf16_moe_op", "_fake_trtllm_bf16_moe"),
    ("trtllm_fp8_per_tensor_scale_moe_op", "_fake_trtllm_fp8_per_tensor_scale_moe"),
    ("trtllm_fp8_block_scale_moe_op", "_fake_trtllm_fp8_block_scale_moe"),
    ("trtllm_fp4_block_scale_moe_op", "_fake_trtllm_fp4_block_scale_moe"),
    ("trtllm_mxint4_block_scale_moe_op", "_fake_trtllm_mxint4_block_scale_moe"),
]
for real, fake in pairs:
    rp = [arg.arg for arg in defs[real].args.args]
    fp = [arg.arg for arg in defs[fake].args.args]
    print(f"{real} vs {fake}: {'OK' if rp == fp else 'MISMATCH'}")
    if rp != fp:
        print("  real:", rp)
        print("  fake:", fp)
PY
```

Repository: flashinfer-ai/flashinfer
Length of output: 5080
Add norm_topk_prob (and missing parameters) to all fake-op signatures to match their real op counterparts.
The real custom ops now accept norm_topk_prob as a trailing parameter, but the corresponding fake ops (_fake_trtllm_bf16_moe, _fake_trtllm_fp8_per_tensor_scale_moe, _fake_trtllm_fp8_block_scale_moe, _fake_trtllm_fp4_block_scale_moe, _fake_trtllm_mxint4_block_scale_moe) retain their old signatures. Fake-op schemas must exactly match their real ops for torch.compile to correctly dispatch through meta tensors. Schema drift will cause failures when callers pass the new argument.
Also applies to: 1512-1512, 1671-1671, 1883-1883, 2106-2106
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/core.py` at line 1347, The fake-op function signatures
(_fake_trtllm_bf16_moe, _fake_trtllm_fp8_per_tensor_scale_moe,
_fake_trtllm_fp8_block_scale_moe, _fake_trtllm_fp4_block_scale_moe,
_fake_trtllm_mxint4_block_scale_moe) must be updated so their parameter lists
exactly match their real-op counterparts: add the trailing norm_topk_prob
boolean parameter and any other missing parameters present in the real op
schemas; adjust the function definitions and any internal arg handling to accept
and forward these parameters (preserve default values), ensuring the fake-op
signatures are identical to the real ops to avoid schema drift during
torch.compile/meta-tensor dispatch.
```cpp
inline int getSMVersion() {
  int device{-1};
  cudaGetDevice(&device);
  int sm_major = 0;
  int sm_minor = 0;
  cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
  return sm_major * 10 + sm_minor;
}

inline int getMultiProcessorCount() {
  int device{-1};
  cudaGetDevice(&device);
  int count = 0;
  cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device);
  return count;
}
```
🧩 Analysis chain
🏁 Script executed:

```shell
# Locate the cudaUtils.h file
find . -name "cudaUtils.h" -type f
```

Repository: flashinfer-ai/flashinfer

Length of output: 173

🏁 Script executed:

```shell
# Check file contents around lines 276-291
if [ -f "include/flashinfer/trtllm/common/cudaUtils.h" ]; then
  sed -n '276,295p' include/flashinfer/trtllm/common/cudaUtils.h
fi
```

Repository: flashinfer-ai/flashinfer

Length of output: 606

🏁 Script executed:

```shell
# Search for FLASHINFER_CHECK macro definition in the repository
rg "define FLASHINFER_CHECK|FLASHINFER_CHECK\(" --type c --type cpp -A 3 | head -50
```

Repository: flashinfer-ai/flashinfer

Length of output: 3430

🏁 Script executed:

```shell
# Check what headers are included at the top of cudaUtils.h
if [ -f "include/flashinfer/trtllm/common/cudaUtils.h" ]; then
  head -50 include/flashinfer/trtllm/common/cudaUtils.h
fi
```

Repository: flashinfer-ai/flashinfer

Length of output: 1634

🏁 Script executed:

```shell
# Search for similar CUDA helper functions in flashinfer to see how they handle errors
rg "cudaGetDevice|cudaDeviceGetAttribute" --type c --type cpp -B 2 -A 2 | head -80
```

Repository: flashinfer-ai/flashinfer

Length of output: 5753
Surface CUDA query failures instead of silently returning bogus device info.
The functions getSMVersion() and getMultiProcessorCount() ignore CUDA status codes from cudaGetDevice() and cudaDeviceGetAttribute(). If the thread lacks an active device context, these helpers silently degrade to device = -1 or count = 0, causing any downstream dispatch logic to take incorrect paths. This pattern is inconsistent with error-checked versions of the same functions in include/flashinfer/trtllm/common.h (uses CUDACHECK) and csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h (uses check_cuda_error()). The FLASHINFER_CHECK macro is already available in this file and used elsewhere (e.g., in cudaStreamIsCapturing()).
Wrap all CUDA calls with FLASHINFER_CHECK() to match the error-handling pattern established in the rest of the codebase.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@include/flashinfer/trtllm/common/cudaUtils.h` around lines 276 - 291, Both
getSMVersion() and getMultiProcessorCount() silently ignore CUDA errors; wrap
every CUDA call inside these functions with FLASHINFER_CHECK() so failures are
propagated instead of returning bogus values. Specifically, in getSMVersion()
call FLASHINFER_CHECK(cudaGetDevice(&device)) and
FLASHINFER_CHECK(cudaDeviceGetAttribute(&sm_major,
cudaDevAttrComputeCapabilityMajor, device)) /
FLASHINFER_CHECK(cudaDeviceGetAttribute(&sm_minor,
cudaDevAttrComputeCapabilityMinor, device)); in getMultiProcessorCount() call
FLASHINFER_CHECK(cudaGetDevice(&device)) and
FLASHINFER_CHECK(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount,
device)) so any CUDA error surfaces via the existing FLASHINFER_CHECK handling.
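The pattern the prompt asks for — check every status code instead of discarding it — generalizes beyond CUDA. A small Python model, where status `0` plays the role of `cudaSuccess` and the injected getters are stand-ins for the driver calls:

```python
def check(status: int, what: str) -> None:
    # Analogue of FLASHINFER_CHECK: surface the failure instead of
    # letting a bogus value flow into dispatch decisions.
    if status != 0:
        raise RuntimeError(f"CUDA call failed ({what}): error {status}")


def get_sm_version(get_device, get_attribute) -> int:
    # Each injected call returns a (status, value) pair, mimicking a
    # driver API whose return code is otherwise easy to ignore.
    status, device = get_device()
    check(status, "cudaGetDevice")
    status, major = get_attribute(device, "ComputeCapabilityMajor")
    check(status, "cudaDeviceGetAttribute(major)")
    status, minor = get_attribute(device, "ComputeCapabilityMinor")
    check(status, "cudaDeviceGetAttribute(minor)")
    return major * 10 + minor


# Healthy device: behaves exactly like the unchecked version.
assert get_sm_version(lambda: (0, 0),
                      lambda dev, name: (0, 9 if "Major" in name else 0)) == 90
```

A failing query now raises at the call site instead of silently degrading to `device = -1` or `count = 0`.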
```cpp
struct SumNormalizePostprocess {
  template <typename OutputT>
  struct Params {
    bool normTopkProb = true;

    void set(routingCustom::Data const& data) { normTopkProb = data.mNormTopkProb; }
  };

  template <typename DataType, int K, typename ParamsT>
  __forceinline__ __device__ static void apply(cg::thread_block_tile<WarpSize> const& warp,
                                               DataType (&warpTopKScore)[K],
                                               int32_t const (&/*warpTopKExpertIdx*/)[K],
                                               int32_t laneIdx, int32_t topK,
                                               ParamsT const& params) {
    float sum = float{1.f};
    if (params.normTopkProb) {
      sum = static_cast<float>(laneIdx < topK ? warpTopKScore[laneIdx] : 0);
      sum = cg::reduce(warp, sum, cg::plus<float>());
    }
    if (laneIdx < topK) {
      warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum;
    }
  }
};
```
Clamp the renormalization denominator for SigmoidRenorm.
SumNormalizePostprocess used to only see softmax outputs, but this PR now reuses it for SigmoidRenorm. With very negative or -inf logits, every selected sigmoid weight can be exactly 0, so the reduction at Line 197 returns 0 and Lines 199-200 produce NaNs.
🛠️ Suggested fix

```diff
     if (params.normTopkProb) {
       sum = static_cast<float>(laneIdx < topK ? warpTopKScore[laneIdx] : 0);
       sum = cg::reduce(warp, sum, cg::plus<float>());
     }
     if (laneIdx < topK) {
-      warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum;
+      float denom = params.normTopkProb ? fmaxf(sum, 1e-20f) : 1.f;
+      warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / denom;
     }
   }
 };
```

Also applies to: 456-462
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh` around lines 180
- 203, SumNormalizePostprocess::apply can produce a zero reduction when reused
for SigmoidRenorm, causing NaNs on division; fix it by clamping the computed
denominator before dividing: after computing sum (the cg::reduce result) ensure
sum = max(sum, epsilon) with a small float epsilon (e.g., 1e-6f) so
warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum is safe; update the apply
method in SumNormalizePostprocess (and the equivalent block around the other
occurrence noted) to perform this clamp when params.normTopkProb is true or
whenever the reduction result may be zero.
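The hazard is easy to reproduce numerically. A pure-Python sketch of the renormalization with and without the clamp (the epsilon value is illustrative: the suggested fix uses `1e-20f`, the prompt mentions `1e-6f`):

```python
def renormalize(scores, eps=0.0):
    # Top-K renormalization: divide by the (optionally clamped) sum.
    denom = max(sum(scores), eps)
    return [x / denom for x in scores]


zeros = [0.0, 0.0, 0.0]  # all selected sigmoid weights underflowed to 0

try:
    renormalize(zeros)  # denom == 0: Python raises; CUDA would yield NaN
    unclamped_ok = True
except ZeroDivisionError:
    unclamped_ok = False
assert not unclamped_ok

# With the clamp the weights stay finite (here: all zero).
assert renormalize(zeros, eps=1e-20) == [0.0, 0.0, 0.0]
# Well-conditioned inputs are unaffected by the clamp.
assert renormalize([0.25, 0.75], eps=1e-20) == [0.25, 0.75]
```

Note the host-side difference: Python raises on `0.0 / 0.0` while CUDA floating-point division would quietly propagate NaN through every routed weight, which is why the clamp belongs in the kernel.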
```cpp
struct Data : public DataBase {
  tg::Dtype mDtypeExpW{tg::Dtype::Fp32};
  tg::Dtype mDtypeElt{tg::Dtype::Bfloat16};
  tg::Dtype mDtypeOutput{tg::Dtype::Fp32};  // OutputT: expert weights dtype (typically Bfloat16)
  tg::Dtype mDtypeInput{tg::Dtype::Bfloat16};  // InputT: routing logits dtype (Bfloat16 or Fp32)

  bool mDoSoftmaxBeforeTopK{false};
  RoutingPreprocessType mPreprocessType{RoutingPreprocessType::None};
  RoutingPostprocessType mPostprocessType{RoutingPostprocessType::Softmax};
  bool mNormTopkProb{true};  // Default value is true for Qwen3 model
  bool mApplySoftmaxAfterTopK{false};

  // Optional: per-expert routing bias (used by SigmoidBias preprocess).
  void const* mPtrRoutingBias{nullptr};
  // Dtype of the routing bias buffer (Bfloat16 or Fp32). Used to read mPtrRoutingBias correctly.
  tg::Dtype mDtypeBias{tg::Dtype::Bfloat16};
  // Optional: scaling factor applied to final scores (used by ScaledSumNormalize postprocess).
  float mRouteScale{1.0f};
  // Optional: epsilon added to the sum before division to prevent division by zero.
  // MiniMax2 uses 1e-20f; DeepSeek uses 0.0f (no epsilon).
  float mSumEpsilon{0.0f};
```
Minor: Default dtype inconsistency with comment.
mDtypeOutput defaults to Fp32, but the comment says "OutputT: expert weights dtype (typically Bfloat16)". Other routing Data structs (routingDeepSeek::Data, routingLlama4::Data) default to Bfloat16.
Is Fp32 intentional for custom routing, or should this be Bfloat16 for consistency?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.h` around lines 327-343,
the Data struct sets mDtypeOutput to Fp32 while the comment and other routing
structs (routingDeepSeek::Data, routingLlama4::Data) expect Bfloat16. Change the
default of Data::mDtypeOutput from tg::Dtype::Fp32 to tg::Dtype::Bfloat16 (or
explicitly document why Fp32 is required) so defaults are consistent; update the
initializer of mDtypeOutput in the Data struct and adjust any related
documentation/comments to match.
Code Review
This pull request refactors the MoE routing logic by introducing a policy-based design, which enhances modularity and flexibility. The changes consolidate duplicated code into shared components like runPostTopKPipeline and add support for more flexible data types for routing logits and biases. Several bug fixes related to expert index calculation and stream synchronization are also included. My review focuses on the new code structure and I have a couple of suggestions to improve readability.
```cpp
// Number of blocks we can use in the cooperative kernel
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
// WAR: Reserve 8 SMs for overlapping kernels.
numBlocksCoop = smCount - 8;
```
The magic number 8 should be defined as a named constant to improve readability and maintainability. This would make it clearer why 8 SMs are being reserved. This pattern is also repeated in csrc/trtllm_fused_moe_routing_custom.cu and csrc/trtllm_fused_moe_routing_deepseek.cu.
For example, you could add static constexpr int kReservedSMsForOverlapping = 8; before this line and use the constant here.
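A compilable sketch of that suggestion (the `getMultiProcessorCount` stub and its return value are placeholders for illustration; the real function lives in `tensorrt_llm::common`):

```cpp
#include <cassert>

// Placeholder for tensorrt_llm::common::getMultiProcessorCount();
// 108 is an illustrative SM count (e.g., an A100).
static int getMultiProcessorCount() { return 108; }

// Named constant replacing the magic number 8, as suggested in the review.
static constexpr int kReservedSMsForOverlapping = 8;

int computeNumBlocksCoop() {
  // Number of blocks we can use in the cooperative kernel.
  static int const smCount = getMultiProcessorCount();
  // WAR: reserve some SMs for overlapping kernels.
  return smCount - kReservedSMsForOverlapping;
}
```

Defining the constant once also lets the other call sites in `csrc/trtllm_fused_moe_routing_custom.cu` and `csrc/trtllm_fused_moe_routing_deepseek.cu` share it.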
```cpp
int mask =
    idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF;
```
📌 Description
Refactor the routing part of the trtllm-gen.
Execution Flow
Each routing method follows the same high-level pattern:
- Compute `ctaIdxXyToBatchIdx`, `ctaIdxXyToMnLimit`, and `numNonExitingCtas` so the downstream batched GEMM knows how to partition work.

Depending on token count, different code paths are selected:
When topK is pre-computed (mPtrTopKIds or mPtrTopKPacked), the first two paths skip topK,
and the coop/large paths skip the scores→topK kernel.
`runPostTopKPipeline<DataType>()` handles the permutation pipeline when topK is already computed (e.g., by DeepSeek's grouped main kernel). It converts any routing method's `Data` to `routingCustom::Data` and dispatches through the appropriate token-count path (single-block, single-cluster, coop, or multi-kernel). This avoids duplicating permutation code across methods.
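The token-count dispatch described above can be illustrated with a toy sketch (the thresholds and path names here are hypothetical; the real code selects CUDA launch configurations, not strings):

```cpp
#include <cassert>
#include <string>

// Hypothetical path selection by token count. The cutoffs are invented
// for illustration only; they do not come from the PR.
std::string selectRoutingPath(int numTokens, bool topKPrecomputed) {
  const char* path = numTokens <= 256    ? "single-block"
                     : numTokens <= 1024 ? "single-cluster"
                     : numTokens <= 8192 ? "coop"
                                         : "multi-kernel";
  // When topK is already computed, the scores->topK stage is skipped
  // and only the permutation pipeline runs.
  return std::string(topKPrecomputed ? "post-topk/" : "full/") + path;
}
```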
Routing Methods and TierList Configuration
RoutingMethodType → Policy Mapping
Each `RoutingMethodType` (defined in `runner.h`) maps to a specific kernel path and policy combination. The `routingCustom` method uses a policy-based design where expert selection logic is injected as a compile-time `ExpertSelectPolicy` template parameter, providing zero runtime overhead and high extensibility.

The default `TopKExpertSelect<PreprocessPolicy, PostprocessPolicy>` wraps the traditional preprocess → topK → postprocess pattern, while users can write completely custom policies that bypass this pattern (e.g., lookup-table-based expert selection).

Each policy owns its runtime data through a nested `Params<OutputT>` struct. When a policy doesn't need extra data, its `Params` is empty and costs zero registers. This avoids paying for unused fields (e.g., a routing bias pointer) in policy combinations that don't need them.

TierList and PolicyTraits
Each routing kernel is templated on `MaxNumExperts` and `MaxNumTopExperts`. These determine shared memory sizes, loop bounds, and the `__launch_bounds__` thread count. To avoid compiling every combination, each policy declares which `(MaxNumExperts, MaxTopK)` pairs it supports via a `PolicyTraits` specialization.

Dispatch: `dispatchTierPairs` iterates the `TierList` from first to last, picking the first `Tier<E, K>` where `numExperts ≤ E` AND `topK ≤ K`. This means tighter tiers must come first (sorted by E ascending, then K ascending within equal E).
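The first-match rule can be sketched as a runtime loop in plain C++ (the tier values below are illustrative; the real dispatch happens at compile time over a `TierList`):

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Illustrative runtime stand-in for the compile-time tier dispatch:
// pick the first (E, K) tier that can accommodate (numExperts, topK).
// Tiers must be sorted tightest-first so the smallest fit wins.
std::pair<int, int> pickTier(const std::vector<std::pair<int, int>>& tiers,
                             int numExperts, int topK) {
  for (auto [e, k] : tiers)
    if (numExperts <= e && topK <= k) return {e, k};
  return {-1, -1};  // no tier supports this combination
}
```

Because dispatch stops at the first match, listing a loose tier before a tight one would silently compile every request against the loose tier, which is why the sort order matters.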
Adding support for a new model: If a new model has a `(numExperts, topK)` combination not covered by any existing tier, add a `Tier<E, K>` to the appropriate `PolicyTraits` specialization. No other changes are needed; the dispatch macros are generic.
Available Policies
| Policy | `Params` fields | Notes |
| --- | --- | --- |
| `NoOpPreprocess` | (none) | `BaseType = InputT` |
| `SoftmaxPreprocess` | (none) | `BaseType = float` |
| `SigmoidPreprocess` | (none) | `sigmoid(score)` (no bias); `BaseType = float` |
| `SigmoidBiasPreprocess` | `ptrRoutingBias`, `dtypeBias` | `sigmoid(score) + bias[expertIdx]`; `BaseType = float` |
| `NoOpPostprocess` | (none) | |
| `SoftmaxPostprocess` | (none) | |
| `SumNormalizePostprocess` | `normTopkProb` | |
| `ScaledSumNormalizePostprocess` | `ptrRoutingBias`, `dtypeBias`, `routeScale`, `sumEpsilon` | |
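As a rough illustration of the zero-cost `Params` idea (a plain C++ sketch; the struct names mirror the policies above, but the layouts are assumptions, not the actual kernel code): a policy whose `Params` is empty adds no storage to the composed kernel parameters when `[[no_unique_address]]` is used.

```cpp
#include <cassert>

// Hypothetical per-policy Params structs (names only; the real layouts
// in the PR may differ).
struct NoOpPreprocessParams {};   // empty: no runtime data needed
struct NoOpPostprocessParams {};  // empty
struct ScaledSumNormalizeParams {
  const void* ptrRoutingBias = nullptr;
  float routeScale = 1.0f;
  float sumEpsilon = 0.0f;  // e.g., 1e-20f for MiniMax2
};

// Composed kernel parameters: empty policy Params can occupy no storage,
// so combinations that need no bias pointer never pay for one.
template <class PreParams, class PostParams>
struct KernelParams {
  [[no_unique_address]] PreParams pre;
  [[no_unique_address]] PostParams post;
  int topK = 8;
};
```

A `KernelParams<NoOpPreprocessParams, NoOpPostprocessParams>` is smaller than one carrying `ScaledSumNormalizeParams`, which is the register-saving effect the description refers to.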
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes
New Features
- `norm_topk_prob` parameter to control top-K probability normalization in MoE routing.
- `SigmoidRenorm` routing method type for enhanced routing flexibility.

Improvements
Bug Fixes