Support Sigmoid (sigmoid+topk) routing function#2869
EdalatiAli wants to merge 2 commits into flashinfer-ai:main
Conversation
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
Summary of Changes (Gemini Code Assist): This pull request enhances the Mixture-of-Experts (MoE) routing capabilities by introducing new sigmoid-based routing functions and a policy-driven dispatch architecture. The changes provide greater flexibility and support for diverse MoE model designs, ensuring efficient and accurate expert selection under various activation and normalization schemes. The refactoring also streamlines the codebase by consolidating common routing utilities and optimizing kernel instantiations.
Code Review
The pull request refactors MoE routing kernels by introducing a new policy-based dispatch system, consolidating common routing logic into trtllm_fused_moe_routing_common.cu and trtllm_fused_moe_routing_custom.cu, and removing the trtllm_fused_moe_routing_renormalize.cu file. It also adds new SigmoidRenorm and Sigmoid routing methods, and introduces a norm_topk_prob parameter to control top-K probability normalization. Review comments highlight concerns about the removal of specific dtype checks for routing_logits in trtllm_fused_moe_kernel_launcher.cu, suggesting that specific checks for each routing method should be maintained. There are also suggestions to use data_ptr() directly instead of static_cast<float*> for routing_logits to prevent type casting issues, and to use loadScalar for routing_bias to handle types correctly. Additionally, feedback points out that checks for routing_bias dimensions and shape should be conditional if not all methods use it, and that hardcoded expert tier values in trtllm_fused_moe_routing_deepseek.cu are not scalable and should be configurable. Minor comment inaccuracies regarding mUsePdl and mPdlOverlapWithNext in trtllm_fused_moe_routing_custom.cu were also noted.
csrc/trtllm_fused_moe_kernel_launcher.cu (1799-1804)
The removal of the specific dtype checks for routing_logits based on RoutingMethodType is concerning. It's important to ensure that all routing methods now correctly handle both dl_float32 and dl_bfloat16 for routing_logits. If certain routing methods still require a specific dtype, this change could introduce errors. It's better to have specific checks for each routing method.
csrc/trtllm_fused_moe_kernel_launcher.cu (393-397)
The addition of mRoutingLogitsDtype and norm_topk_prob as arguments to the routing_runner.run function call is correct, but it's crucial to ensure that these parameters are correctly passed and handled in all subsequent calls to this function throughout the codebase. It's important to verify that the data types and values are consistent with the expected behavior of the routing kernel.
csrc/trtllm_fused_moe_kernel_launcher.cu (904-907)
The static_cast<float*> here is concerning. It's better to use data_ptr() directly and let the kernel handle the type. This is especially important if routing_logits is not always a float*. This could lead to incorrect memory access or type casting issues.
args->routing_logits = static_cast<float*>(routing_logits.value().data_ptr());
csrc/trtllm_fused_moe_routing_deepseek.cu (162-165)
Using loadScalar is better than static_cast<float>(params.mPtrRoutingBias[threadExpert]) as it handles the type correctly. This avoids potential type casting issues.
? static_cast<OutputT>(
loadScalar(params.mPtrRoutingBias, threadExpert, params.mDtypeBias))
: invalidScore;
csrc/trtllm_fused_moe_kernel_launcher.cu (1735-1736)
The type validation for routing_logits is now more general, allowing both dl_float32 and dl_bfloat16. This is good for flexibility, but it's crucial to ensure that the kernel implementation correctly handles both data types. Add a comment to explain why the type check is now more permissive.
TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
<< "FP8 per-tensor MoE: routing_logits must be float or bfloat16."; // Allow both float32 and bfloat16
csrc/trtllm_fused_moe_kernel_launcher.cu (929-930)
Using routing_logits.value().data_ptr() is better than static_cast<float*>(routing_logits.value().data_ptr()) as it avoids potential type casting issues. The kernel should handle the type appropriately.
args->routing_logits = routing_logits.value().data_ptr();
csrc/trtllm_fused_moe_kernel_launcher.cu (1840-1846)
The addition of checks for routing_bias dimensions and shape is good for robustness. However, it's important to ensure that these checks are consistent with the expected behavior of all routing methods that use routing_bias. If some methods don't use it, the check should be conditional.
csrc/trtllm_fused_moe_routing_deepseek.cu (67-77)
The use of hardcoded values for expert tiers is not scalable. These values should be configurable or derived from the input data to allow for flexibility in different MoE configurations. Consider using a function or a lookup table to determine the appropriate expert tier based on the number of experts.
csrc/trtllm_fused_moe_kernel_launcher.cu (881-892)
Adding RoutingMethodType::Sigmoid to this conditional block is correct for enabling the new routing method. However, it's important to ensure that the logic within this block is appropriate for all routing methods included, and that the comment accurately reflects the supported top_k values for all methods in the group. Consider updating the comment to be more general, or adding separate checks with distinct comments for each routing method if their requirements diverge.
TVM_FFI_ICHECK(args->top_k <= 10 && args->top_k > 0)
<< "Current routing kernel (no groups) only supports top_k<=10 && top_k>0.";
csrc/trtllm_fused_moe_routing_custom.cu (48-50)
The comment says that when MaxNumExperts > 1024 the thread count is capped at 1024, and the code matches: NumThreadsBlock is assigned MaxNumExperts <= 1024 ? MaxNumExperts : 1024. The comment also claims that each thread then handles multiple experts; please verify that the loop actually strides over experts when MaxNumExperts exceeds the block size.
csrc/trtllm_fused_moe_routing_custom.cu (82-84)
Using params.mUsePdl is better than KernelParams::UsePdl as it uses the runtime value instead of the compile time value.
if (params.mUsePdl) {
csrc/trtllm_fused_moe_routing_custom.cu (311-312)
Using params.mUsePdl is better than KernelParams::UsePdl as it uses the runtime value instead of the compile time value.
if (params.mUsePdl) {
csrc/trtllm_fused_moe_routing_custom.cu (571)
The comment is not accurate. The code is checking for data.mUsePdl not KernelParams::UsePdl.
csrc/trtllm_fused_moe_routing_custom.cu (600)
The comment is not accurate. The code is checking for mutableData.mPdlOverlapWithNext not data.mPdlOverlapWithNext.
csrc/trtllm_fused_moe_routing_custom.cu (634)
The comment is not accurate. The code is checking for mutableData.mPdlOverlapWithNext not data.mPdlOverlapWithNext.
csrc/trtllm_fused_moe_routing_custom.cu (640)
The comment is not accurate. The code is checking for mutableData.mPdlOverlapWithNext not data.mPdlOverlapWithNext.
Signed-off-by: EdalatiAli <aliedalati@cohere.com>
📌 Description
Depends on #2803.
This PR adds RoutingMethodType.Sigmoid to support a routing function that applies sigmoid before top-k (without renormalization), for use by MoE layers that rely on this routing scheme.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Added or updated tests as needed (unittest, etc.).
Reviewer Notes