[None][feat] Add alltoall to trtllm-gen MoE backend.#8481
bobboli merged 1 commit into NVIDIA:main
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
/bot run

PR_Github #21816 [ run ] triggered by Bot. Commit:
📝 Walkthrough

This change introduces all-to-all (alltoall) MoE support and quantization pre-communication handling to the TensorRT-LLM fused MoE module. It adds workspace allocation, a cached property that decides whether alltoall is enabled, and quantization routines that run before communication. It also extends the forward pass to route through new alltoallv preparation and consolidation paths while maintaining backward compatibility with the existing post-quant allgather flows.
Sequence Diagram(s)
sequenceDiagram
participant FwdImpl as forward_impl()
participant PostQuant as Quantization Decision
participant Prep as Alltoall Prepare
participant Alltoallv as Alltoallv Ops
participant Final as Result Consolidation
FwdImpl->>PostQuant: Check post_quant_comm flag
alt post_quant_comm enabled
PostQuant->>PostQuant: _quantize_for_post_quant_comm(x)<br/>(fp8/nvfp4/mxfp4 modes)
alt enable_alltoall
PostQuant->>Prep: mnnvl_moe_alltoallv_prepare<br/>_without_allgather()
Prep->>Alltoallv: Execute alltoallv
Alltoallv->>Final: mnnvl_moe_alltoallv()<br/>consolidate results
Final->>Final: memset_expert_ids<br/>reshape token_final_scales<br/>token_selected_experts
Final->>FwdImpl: mnnvl_moe_alltoallv_combine()
else run_post_quant_allgather
PostQuant->>Final: Existing allgather path
Final->>FwdImpl: Return results
end
else post_quant_comm disabled
PostQuant->>FwdImpl: Standard processing
end
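The dispatch/compute/combine flow in the diagram can be illustrated with a single-process, pure-Python simulation. This is a sketch under stated assumptions, not how MnnvlMoe is implemented: ranks are simulated as list indices, experts are assumed to be split contiguously across EP ranks, and simulate_alltoallv_moe, expert_owner, and the toy expert function are hypothetical. The per-destination send lists stand in for the routing metadata a prepare step would compute.

```python
from typing import Callable, List, Tuple

# Hypothetical toy model: each "token" is a single float.
Token = float


def expert_owner(expert_id: int, num_experts: int, ep_size: int) -> int:
    """Rank owning expert_id, assuming a contiguous split (an assumption)."""
    experts_per_rank = num_experts // ep_size
    return expert_id // experts_per_rank


def simulate_alltoallv_moe(
    tokens: List[List[Token]],          # tokens[rank][i]
    selected: List[List[List[int]]],    # selected[rank][i] -> top-k expert ids
    scales: List[List[List[float]]],    # scales[rank][i] -> top-k routing weights
    num_experts: int,
    expert_fn: Callable[[int, Token], Token],
) -> List[List[Token]]:
    ep_size = len(tokens)
    # 1) Prepare: each source rank builds one send list per destination rank.
    #    Entries remember (src_rank, token_idx, k) so results can be routed back.
    send: List[List[List[Tuple[int, int, int, int, Token]]]] = [
        [[] for _ in range(ep_size)] for _ in range(ep_size)
    ]
    for src in range(ep_size):
        for i, tok in enumerate(tokens[src]):
            for k, e in enumerate(selected[src][i]):
                dst = expert_owner(e, num_experts, ep_size)
                send[src][dst].append((src, i, k, e, tok))
    # 2) Alltoallv dispatch: rank dst receives send[src][dst] from every src.
    recv = [[item for src in range(ep_size) for item in send[src][dst]]
            for dst in range(ep_size)]
    # 3) Local expert computation on each destination rank.
    computed = [[(src, i, k, expert_fn(e, tok)) for (src, i, k, e, tok) in recv[dst]]
                for dst in range(ep_size)]
    # 4) Combine: return results to the source rank and reduce over top-k
    #    using the routing weights.
    out = [[0.0] * len(tokens[r]) for r in range(ep_size)]
    for dst in range(ep_size):
        for src, i, k, y in computed[dst]:
            out[src][i] += scales[src][i][k] * y
    return out
```

With 2 ranks, 4 experts, and expert_fn = lambda e, t: (e + 1) * t, the result matches what a single rank holding all experts would compute, which is the property the combine step has to preserve.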
Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

The changes introduce significant new control flow for alltoall MoE support, add multiple new workspace initialization paths, implement pre-communication quantization with handling for multiple modes (fp8, nvfp4, mxfp4/mxfp8), and interleave new branches throughout the forward pass alongside existing code paths. The review requires understanding the interplay between the new alltoallv preparation/consolidation flow, the quantization logic, the existing post-quant allgather paths, and workspace management, with attention to type casting, reshaping, and state management across the different quantization modes.

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (2)
57-74: Avoid mutable default for model_config.

Using ModelConfig() as a default binds a single instance at import time. Use None and instantiate inside.

    -        model_config: ModelConfig = ModelConfig(),
    +        model_config: Optional[ModelConfig] = None,
    @@
    -        super().__init__(
    +        if model_config is None:
    +            model_config = ModelConfig()
    +        super().__init__(
                 routing_method=routing_method,
                 num_experts=num_experts,
                 hidden_size=hidden_size,
                 intermediate_size=intermediate_size,
                 dtype=dtype,
                 reduce_results=reduce_results,
                 model_config=model_config,
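The pitfall is easy to reproduce in plain Python. ModelConfig below is a hypothetical stand-in with a single mutable attribute, not the real TRT-LLM class. The point is only that a default expression is evaluated once, when the def statement runs.

```python
from typing import Optional


class ModelConfig:
    """Hypothetical stand-in for the real ModelConfig (fields are illustrative)."""

    def __init__(self) -> None:
        self.extra_attrs = {}  # mutable state that callers may modify


class BadModule:
    # The default ModelConfig() is created ONCE, when this def is executed,
    # and then shared by every instance constructed without an argument.
    def __init__(self, model_config: ModelConfig = ModelConfig()) -> None:
        self.model_config = model_config


class GoodModule:
    def __init__(self, model_config: Optional[ModelConfig] = None) -> None:
        # A fresh ModelConfig is created per call when none is passed.
        if model_config is None:
            model_config = ModelConfig()
        self.model_config = model_config
```

Two BadModule() instances share one config, so a mutation through one is visible through the other. GoodModule() instances each get their own.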
252-253: Fix dtype assert for Union[torch.Tensor, Fp4QuantizedTensor].

The current assert may fail when x is an Fp4QuantizedTensor.

    -        assert x.dtype == torch.bfloat16
    +        if isinstance(x, torch.Tensor):
    +            assert x.dtype == torch.bfloat16, "expected bf16 hidden states"
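Here is a torch-free sketch of the guarded check. It uses hypothetical stand-in classes (FakeTensor, FakeFp4QuantizedTensor) in place of torch.Tensor and Fp4QuantizedTensor. For illustration, it assumes the quantized wrapper exposes no dtype attribute, so an unguarded x.dtype access would raise AttributeError.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only carries a dtype name."""
    dtype: str


@dataclass
class FakeFp4QuantizedTensor:
    """Hypothetical stand-in for a pre-quantized input: data plus scales, no dtype."""
    data: FakeTensor
    scale: FakeTensor


def check_hidden_states(x: Union[FakeTensor, FakeFp4QuantizedTensor]) -> str:
    # Only plain tensors are required to be bf16; pre-quantized inputs skip
    # the check instead of failing on a missing or different dtype.
    if isinstance(x, FakeTensor):
        assert x.dtype == "bfloat16", "expected bf16 hidden states"
        return "bf16"
    return "pre-quantized"
```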
🧹 Nitpick comments (5)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (5)
8-8: Prefer a module-namespace import for _mnnvl_utils.

This keeps the API surface stable and aligns with the "maintain module namespace" rule from the coding guidelines.

    -from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
    +from tensorrt_llm import _mnnvl_utils as mnnvl_utils

Follow-up: replace MnnvlMemory/MnnvlMoe usages with mnnvl_utils.MnnvlMemory/mnnvl_utils.MnnvlMoe.
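One practical consequence of the module-namespace style is late binding. A qualified name such as mnnvl_utils.MnnvlMoe is looked up at call time, while a directly imported name keeps whatever object it was bound to at import time. The sketch below shows this with a hypothetical in-memory module standing in for tensorrt_llm._mnnvl_utils. The class and function names are illustrative only.

```python
import types

# Hypothetical module standing in for tensorrt_llm._mnnvl_utils.
mnnvl_utils = types.ModuleType("mnnvl_utils_demo")


class MnnvlMoe:
    """Stand-in class; not the real MnnvlMoe."""

    @staticmethod
    def name() -> str:
        return "real"


mnnvl_utils.MnnvlMoe = MnnvlMoe

# Equivalent of "from module import MnnvlMoe": binds the class object once.
DirectMnnvlMoe = mnnvl_utils.MnnvlMoe


def via_namespace() -> str:
    # Resolved through the module object every time this runs.
    return mnnvl_utils.MnnvlMoe.name()


def via_direct_import() -> str:
    # Uses the binding captured when the name was imported.
    return DirectMnnvlMoe.name()
```

Replacing the module attribute, as a test mock would, affects only the qualified call.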
112-120: Use self.mapping instead of model_config.mapping when allocating workspaces.

This ensures consistency if the instance's mapping diverges from the constructor argument.

    -            MnnvlMemory.initialize()
    -            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
    -                model_config.mapping)
    -            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
    -                model_config.mapping)
    +            MnnvlMemory.initialize()
    +            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(self.mapping)
    +            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(self.mapping)

If adopting the module-namespace import, prefix these calls with mnnvl_utils. accordingly, per the coding guidelines.
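The following minimal sketch shows why the allocation should read self.mapping: the instance mapping can diverge from the constructor argument. Mapping, get_moe_workspace, MoEModule, and the override_ep_size parameter are all hypothetical. The real Mapping and MnnvlMoe.get_moe_workspaces are not modeled.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class Mapping:
    """Hypothetical stand-in for the parallelism mapping (two fields only)."""
    world_size: int
    moe_ep_size: int


def get_moe_workspace(mapping: Mapping) -> str:
    # Hypothetical allocator: the workspace layout depends on the EP size.
    return f"workspace[ep={mapping.moe_ep_size}]"


class MoEModule:
    def __init__(self, model_config_mapping: Mapping,
                 override_ep_size: Optional[int] = None) -> None:
        # Assumption for illustration: the module may derive its own mapping,
        # so self.mapping can differ from the constructor argument.
        if override_ep_size is not None:
            self.mapping = replace(model_config_mapping, moe_ep_size=override_ep_size)
        else:
            self.mapping = model_config_mapping
        # Allocate from self.mapping so the workspace matches what forward() uses.
        self.alltoall_workspace = get_moe_workspace(self.mapping)
```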
125-133: Add a short docstring and verify that supports_mnnvl() is safe to call before initialize().

Minor polish; also confirm that calling supports_mnnvl() before initialize() is safe across platforms.

         @cached_property
         def enable_alltoall(self):
    +        """Whether to enable alltoallv MoE comm (vs. allgather/reducescatter)."""
             mapping = self.mapping

Could you confirm that MnnvlMemory.supports_mnnvl() does not require a prior initialize() call?
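The docstring and the ordering question can both be exercised with functools.cached_property (available since Python 3.8, matching the guideline's minimum). FakeMnnvlMemory is a hypothetical stand-in that deliberately refuses supports_mnnvl() before initialize(). Whether the real MnnvlMemory behaves this way is exactly what the review asks the author to confirm.

```python
from functools import cached_property


class FakeMnnvlMemory:
    """Hypothetical stand-in; real MnnvlMemory semantics are not modeled."""
    initialized = False
    probe_calls = 0

    @classmethod
    def initialize(cls) -> None:
        cls.initialized = True

    @classmethod
    def supports_mnnvl(cls) -> bool:
        # This toy version refuses to answer before initialize(), which is
        # the ordering hazard the review asks about.
        cls.probe_calls += 1
        if not cls.initialized:
            raise RuntimeError("supports_mnnvl() called before initialize()")
        return True


class Module:
    def __init__(self, ep_size: int) -> None:
        self.ep_size = ep_size

    @cached_property
    def enable_alltoall(self) -> bool:
        """Whether to enable alltoallv MoE comm (vs. allgather/reducescatter)."""
        # Short-circuits: the probe only runs when EP is actually used.
        return self.ep_size > 1 and FakeMnnvlMemory.supports_mnnvl()
```

A failed first access caches nothing, so the property is recomputed on the next access. Once it succeeds, the value is fixed for the lifetime of the instance.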
199-240: Guard the w4a16_mxfp4 padding and simplify the unsupported-mode error.

Prevent negative padding (a negative pad size would make torch.nn.functional.pad crop the input instead of failing) and make the error concise (Ruff TRY003 hint).

    @@
    -        elif self.has_w4a16_mxfp4:
    -            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
    -            x = torch.nn.functional.pad(x, (0, pad_size))
    +        elif self.has_w4a16_mxfp4:
    +            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
    +            if pad_size < 0:
    +                raise RuntimeError("w4a16_mxfp4: input width exceeds expected padded width")
    +            x = torch.nn.functional.pad(x, (0, pad_size))
    @@
    -        else:
    -            raise ValueError(
    -                f"unsupported quantization mode for post communication: {self.quant_config.quant_mode}"
    -            )
    +        else:
    +            raise NotImplementedError("Unsupported quant mode for post-communication")
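The guard can be sketched without torch on plain Python lists. pad_last_dim is hypothetical. In this sketch, expected_width plays the role of self.w3_w1_weight.shape[-1] * 2 from the suggestion for 199-240, and the error text is copied from it.

```python
from typing import List


def pad_last_dim(row: List[float], expected_width: int) -> List[float]:
    """Right-pads a row with zeros to expected_width, refusing to truncate."""
    pad_size = expected_width - len(row)
    if pad_size < 0:
        # Without this guard, torch.nn.functional.pad would treat the
        # negative size as a crop and silently drop trailing elements.
        raise RuntimeError("w4a16_mxfp4: input width exceeds expected padded width")
    return row + [0.0] * pad_size
```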
605-616: Consider explicit combine options.

If you later need non-reduced outputs or a low-precision combine, expose do_reduce/use_low_precision_combine here rather than relying on defaults.
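A toy combine function shows what making the options explicit could look like. The flag names do_reduce and use_low_precision_combine come from the review comment. The semantics below are assumptions for illustration, not the behavior of mnnvl_moe_alltoallv_combine: summing over top-k, and emulating low precision by rounding through IEEE half precision with struct's 'e' format.

```python
import struct
from typing import List, Union


def to_half(x: float) -> float:
    """Rounds a float through IEEE half precision (struct format 'e')."""
    return struct.unpack("e", struct.pack("e", x))[0]


def combine(
    expert_outputs: List[List[float]],  # [num_tokens][top_k]
    scales: List[List[float]],          # [num_tokens][top_k]
    *,
    do_reduce: bool = True,
    use_low_precision_combine: bool = False,
) -> Union[List[float], List[List[float]]]:
    """Toy combine step with its options spelled out as keyword-only flags."""
    weighted = [[s * y for s, y in zip(srow, yrow)]
                for srow, yrow in zip(scales, expert_outputs)]
    if use_low_precision_combine:
        # Emulate exchanging partial results in a lower-precision format.
        weighted = [[to_half(v) for v in row] for row in weighted]
    if not do_reduce:
        return weighted  # one partial result per top-k slot
    return [sum(row) for row in weighted]
```

Making the flags keyword-only means call sites must name them, so a later change in defaults cannot silently alter behavior.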
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (15 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with the current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (3)
tensorrt_llm/_mnnvl_utils.py (8)
MnnvlMemory (53-338)
MnnvlMoe (352-624)
get_moe_workspaces (360-376)
get_moe_prepare_workspace (379-390)
supports_mnnvl (332-338)
mnnvl_moe_alltoallv_prepare_without_allgather (402-446)
mnnvl_moe_alltoallv (531-592)
mnnvl_moe_alltoallv_combine (595-624)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)
enable_alltoall (317-320)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (1)
enable_alltoall (192-197)
🪛 Ruff (0.14.0)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
236-238: Avoid specifying long messages outside the exception class
(TRY003)
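For the TRY003 finding, the usual Ruff-friendly fix is to move the long message into a dedicated exception class. Unlike a shortened generic message, this also keeps the offending mode available to callers. UnsupportedQuantModeError and quantize_for_comm are hypothetical names, and the mode strings are illustrative.

```python
class UnsupportedQuantModeError(NotImplementedError):
    """Raised when a quantization mode has no post-communication path."""

    def __init__(self, quant_mode: object) -> None:
        # The long message lives inside the exception class (Ruff TRY003).
        super().__init__(
            f"unsupported quantization mode for post communication: {quant_mode}")
        self.quant_mode = quant_mode


def quantize_for_comm(quant_mode: str) -> str:
    # Hypothetical dispatcher; the supported mode names are illustrative.
    if quant_mode in ("fp8", "nvfp4", "mxfp4"):
        return quant_mode
    raise UnsupportedQuantModeError(quant_mode)
```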
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (1)
268-271: Gating looks right.

The unified post_quant_comm flag and the exclusion of allgather when alltoall is enabled read correctly.
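The gating rule can be summarized as a small truth table. This is a hedged sketch: plan_comm and CommPlan are hypothetical. The condition under which allgather would otherwise run (attention DP with more than one rank) is an assumption, not something taken from this PR.

```python
from typing import NamedTuple


class CommPlan(NamedTuple):
    post_quant_comm: bool
    run_post_quant_allgather: bool
    use_alltoall: bool


def plan_comm(enable_alltoall: bool, use_dp: bool, parallel_size: int) -> CommPlan:
    # Assumption: allgather is the fallback path for attention DP with >1 rank,
    # and it is excluded whenever alltoall is enabled.
    run_post_quant_allgather = use_dp and parallel_size > 1 and not enable_alltoall
    # One unified flag: quantize before communication if either path runs.
    post_quant_comm = run_post_quant_allgather or enable_alltoall
    return CommPlan(post_quant_comm, run_post_quant_allgather, enable_alltoall)
```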
PR_Github #21816 [ run ] completed with state
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>
Summary by CodeRabbit
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.
Details
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline, so that all builds and tests run regardless of previous successes.
--disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL): Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.
skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since a lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since a lack of user care and validation can cause top of tree to break.
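For readers scripting against the bot, a hypothetical argparse mirror of the /bot grammar described in this help text shows how a comment would tokenize. It is not the bot's implementation. Only the subcommands and flag spellings come from the help text, and the "last" sentinel for a bare --reuse-test is an assumption.

```python
import argparse
import shlex
from typing import List


def build_bot_parser() -> argparse.ArgumentParser:
    """Hypothetical mirror of the /bot grammar; not the real bot."""
    parser = argparse.ArgumentParser(prog="/bot")
    sub = parser.add_subparsers(dest="command", required=True)

    run = sub.add_parser("run")
    # Assumption: a bare --reuse-test means "reuse the last pipeline".
    run.add_argument("--reuse-test", nargs="?", const="last", default=None,
                     metavar="pipeline-id")
    for flag in ("--disable-reuse-test", "--disable-fail-fast", "--skip-test",
                 "--only-multi-gpu-test", "--disable-multi-gpu-test",
                 "--add-multi-gpu-test", "--post-merge", "--detailed-log",
                 "--debug"):
        run.add_argument(flag, action="store_true")
    for option in ("--stage-list", "--gpu-type", "--test-backend", "--extra-stage"):
        run.add_argument(option, default=None)

    sub.add_parser("kill")
    skip = sub.add_parser("skip")
    skip.add_argument("--comment", required=True)
    sub.add_parser("reuse-pipeline")
    return parser


def parse_bot_comment(comment: str) -> argparse.Namespace:
    """Parses a PR comment such as '/bot run --disable-fail-fast'."""
    tokens = shlex.split(comment)
    if not tokens or tokens[0] != "/bot":
        raise ValueError("not a /bot command")
    return build_bot_parser().parse_args(tokens[1:])


def split_list(value: str) -> List[str]:
    # Comma-separated lists such as "A10-PyTorch-1, xxx".
    return [item.strip() for item in value.split(",") if item.strip()]
```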