[None][feat] Add PyTorch Runtime Support for MoE Weight Prefetching #6272
nvxuanyuc wants to merge 4 commits into NVIDIA:main from
Conversation
📝 Walkthrough

This update introduces a Mixture of Experts (MoE) weight prefetching feature across the codebase. It adds new configuration options, propagates prefetching proxies through model and layer constructors, and implements a prefetch manager and proxy. Model forward passes are updated to trigger asynchronous weight prefetching, and new tests validate the feature.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant CLI/Config
    participant Model
    participant Layer
    participant MoE
    participant PrefetchManager
    participant PrefetchProxy
    User->>CLI/Config: Enable --use_moe_prefetch, set depth/stride
    CLI/Config->>Model: Instantiate with MoeConfig (use_moe_prefetch, ...)
    Model->>PrefetchManager: __moe_prefetch_init__()
    PrefetchManager->>PrefetchProxy: Create proxies per MoE layer
    Model->>Layer: Pass PrefetchProxy to MoE layers
    Layer->>MoE: Pass PrefetchProxy to MoE module
    loop For each forward pass
        Model->>PrefetchManager: prefetch_weights(cur_stream)
        PrefetchManager->>PrefetchProxy: Initiate async copy to device
        Model->>Layer: Forward call
        Layer->>MoE: Use prefetched weights if enabled
        MoE->>PrefetchProxy: Start next layer prefetching
    end
```
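The buffer-cycling behavior implied by the diagram above can be sketched, torch-free, as a round-robin double-buffering scheme. All names here (`PrefetchManagerSketch`, `prefetch`, `get_buffer`) are illustrative, not the PR's actual API; real prefetching would issue asynchronous host-to-device copies on a side CUDA stream.

```python
class PrefetchManagerSketch:
    """Toy model of depth-limited MoE weight prefetching (CPU-only sketch)."""

    def __init__(self, moe_layer_ids, depth=2):
        self.moe_layer_ids = list(moe_layer_ids)
        self.depth = depth
        # slot -> layer id whose weights currently occupy that device buffer
        self.resident = {}

    def prefetch(self, order_idx):
        """Simulate copying the order_idx-th MoE layer's weights into its slot."""
        layer_id = self.moe_layer_ids[order_idx]
        slot = order_idx % self.depth  # buffers are reused round-robin
        self.resident[slot] = layer_id
        return slot

    def get_buffer(self, order_idx):
        """Forward pass reads the slot; assert the right layer is resident."""
        slot = order_idx % self.depth
        assert self.resident.get(slot) == self.moe_layer_ids[order_idx], \
            "buffer was overwritten before use"
        return slot


mgr = PrefetchManagerSketch(moe_layer_ids=[3, 5, 7, 9], depth=2)
mgr.prefetch(0)           # layer 3 -> slot 0
mgr.prefetch(1)           # layer 5 -> slot 1
slot = mgr.get_buffer(0)  # forward of layer 3 reads slot 0
mgr.prefetch(2)           # layer 7 reuses slot 0 once layer 3 is done
```

With `depth=2`, only two layers' weights are ever resident at once, which is the memory saving the feature targets.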
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~90 minutes
/bot run --disable-fail-fast
Actionable comments posted: 2
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (1)
357-370: Consider refactoring to reduce code duplication. The stream synchronization and prefetch triggering logic is correctly implemented but duplicated across the single- and multi-chunk paths. Consider extracting this pattern into helper methods to improve maintainability.
```diff
+    def _sync_prefetch_stream(self):
+        """Synchronize with prefetch stream if prefetching is enabled."""
+        if self.use_prefetch:
+            torch.cuda.current_stream().wait_stream(
+                self.prefetch_proxy.prefetch_stream)
+
+    def _trigger_next_prefetch(self):
+        """Trigger prefetching for next layer if enabled."""
+        if self.use_prefetch:
+            self.prefetch_proxy.start_next_layer_prefetching(
+                torch.cuda.current_stream())
+
     def forward(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
         do_finalize: bool = True,  # used by other MoE backends
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         all_rank_max_num_tokens: Optional[int] = None,
         use_dp_padding: Optional[bool] = None,
     ) -> torch.Tensor:
         ...
         if num_chunks == 1:
-            if self.use_prefetch:
-                torch.cuda.current_stream().wait_stream(
-                    self.prefetch_proxy.prefetch_stream)
+            self._sync_prefetch_stream()
             outputs = self.forward_chunk(
                 x,
                 router_logits,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding)
-            if self.use_prefetch:
-                self.prefetch_proxy.start_next_layer_prefetching(
-                    torch.cuda.current_stream())
+            self._trigger_next_prefetch()
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding)
         else:
             ...
-            if self.use_prefetch:
-                torch.cuda.current_stream().wait_stream(
-                    self.prefetch_proxy.prefetch_stream)
+            self._sync_prefetch_stream()
             ...
             outputs = torch.cat(outputs_list)
-            if self.use_prefetch:
-                self.prefetch_proxy.start_next_layer_prefetching(
-                    torch.cuda.current_stream())
+            self._trigger_next_prefetch()
```

Also applies to: 394-396, 448-450
tensorrt_llm/llmapi/llm_args.py (1)
2068-2077: Model validator implementation is correct with sensible defaults. The validator properly creates a default `MoEPrefetchConfig` when prefetching is enabled but no explicit config is provided. The default values (depth=2, stride=1) are reasonable starting points.

One minor suggestion: consider moving the import to the top of the file for better code organization.

```diff
+from .._torch.model_config import MoEPrefetchConfig

 # ... existing imports ...

 @model_validator(mode="after")
 def validate_moe_prefetch_config(self):
-    from .._torch.model_config import MoEPrefetchConfig
     if self.moe_config.use_moe_prefetch and self.moe_config.moe_prefetch_config is None:
```

tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py (1)
63-63: Fix typo in assertion message.

```diff
-        assert len(weights) == 2, "Experted two weight tensors per moe layer"
+        assert len(weights) == 2, "Expected two weight tensors per moe layer"
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (16)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/model_config.py (2 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (8 hunks)
- tensorrt_llm/_torch/models/modeling_llama.py (8 hunks)
- tensorrt_llm/_torch/models/modeling_mixtral.py (7 hunks)
- tensorrt_llm/_torch/models/modeling_utils.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/__init__.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (7 hunks)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (3 hunks)
- tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/config.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (4 hunks)
- tensorrt_llm/llmapi/llm_args.py (3 hunks)
- tests/integration/defs/accuracy/test_llm_api_pytorch.py (1 hunks)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py
80-80: Line too long (148 > 120)
(E501)
82-82: Line too long (139 > 120)
(E501)
88-88: Line too long (127 > 120)
(E501)
129-129: Line too long (151 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (37)
tensorrt_llm/_torch/modules/fused_moe/__init__.py (2)
10-10: LGTM: Clean import addition for prefetch infrastructure. The import of `MoEPrefetchManager` and `MoEPrefetchProxy` correctly exposes the new prefetching infrastructure through the package's public API.

42-43: LGTM: Proper public API exposure. The addition of the new classes to `__all__` follows standard Python packaging practices and maintains alphabetical ordering.

tensorrt_llm/_torch/pyexecutor/config.py (2)
12-12: LGTM: Proper import update for MoE prefetch configuration. The import correctly includes `MoEPrefetchConfig` alongside the existing `MoeLoadBalancerConfig`.

52-52: LGTM: Correct configuration field addition. The `moe_prefetch_config` field follows the established pattern for optional configuration parameters, with proper type annotation and default value.

tensorrt_llm/_torch/model_config.py (2)
60-64: LGTM: Well-designed MoE prefetch configuration. The `MoEPrefetchConfig` dataclass has sensible defaults and clear parameter names:

- `prefetch_depth=2`: Reasonable buffer depth for overlapping computation with memory transfers
- `prefetch_stride=1`: Conservative default for sparse prefetching
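Based on the defaults quoted above, a minimal sketch of such a config dataclass looks as follows; the class name is a stand-in and field names are assumed from the review text, not copied from the PR.

```python
from dataclasses import dataclass


@dataclass
class MoEPrefetchConfigSketch:
    """Illustrative stand-in for the reviewed MoEPrefetchConfig."""
    prefetch_depth: int = 2   # number of rotating device-side weight buffers
    prefetch_stride: int = 1  # prefetch every Nth MoE layer


cfg = MoEPrefetchConfigSketch()
```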
81-81: LGTM: Proper integration into ModelConfig. The `moe_prefetch_config` field correctly follows the established pattern for optional configuration parameters in the ModelConfig class.

tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3)
16-16: LGTM: Correct import for prefetch proxy. The import of `MoEPrefetchProxy` is properly placed and enables the prefetching functionality integration.

65-65: LGTM: Proper function signature extension. The `moe_prefetch_proxy` parameter is correctly typed as optional and follows the existing parameter pattern in the `create_moe` function.

101-101: LGTM: Correct selective proxy passing. The prefetch proxy is correctly passed only to `CutlassFusedMoE`, which aligns with the implementation scope mentioned in the PR objectives: prefetching is currently supported only in the CUTLASS backend.

tensorrt_llm/_torch/modules/fused_moe/quantization.py (1)
75-98: LGTM: Well-implemented conditional weight allocation for prefetching. The implementation correctly handles two scenarios:

Prefetch enabled (lines 75-86):
- Allocates weights in CPU pinned memory for efficient host-to-device transfers
- Registers weights with the prefetch proxy for asynchronous management
- Uses `pin_memory()` for optimal DMA transfer performance

Prefetch disabled (lines 87-95):
- Falls back to standard GPU memory allocation
- Maintains existing behavior for backward compatibility

Both paths properly register parameters with the module, ensuring a consistent API regardless of prefetch configuration.
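The two allocation paths described above can be sketched torch-free; `allocate_weight`, the dict representation, and `proxy_registry` are all hypothetical stand-ins for the real tensor allocation and proxy registration.

```python
def allocate_weight(shape, use_prefetch, proxy_registry):
    """Sketch of the conditional allocation: pinned host memory when
    prefetching (so the proxy can stream it to the GPU later),
    otherwise straight to device memory."""
    if use_prefetch:
        weight = {"shape": shape, "memory": "cpu_pinned"}
        proxy_registry.append(weight)  # hand weight to the prefetch proxy
    else:
        weight = {"shape": shape, "memory": "cuda"}
    return weight


registry = []
w_prefetch = allocate_weight((8, 16), use_prefetch=True, proxy_registry=registry)
w_resident = allocate_weight((8, 16), use_prefetch=False, proxy_registry=registry)
```

In the real code the pinned-memory branch would use `tensor.pin_memory()` and later `copy_(..., non_blocking=True)`, which is what makes the host-to-device transfer overlappable with compute.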
tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
1263-1287: LGTM! Well-structured test for the MoE weight prefetching feature. The test method effectively validates the MoE weight prefetching functionality across different quantization modes and parallelism configurations. The parameterization provides good coverage, and the use of a high GPU memory fraction (0.9) appropriately simulates memory pressure scenarios where prefetching would be beneficial.
tensorrt_llm/_torch/models/modeling_utils.py (2)
24-25: Clean import expansion for MoE prefetching classes. The addition of `MoEPrefetchManager` and `MoEPrefetchProxy` to the imports is necessary for the new MoE weight prefetching functionality and follows the existing import pattern.
243-264: Well-designed initialization method for MoE prefetching infrastructure. The `__moe_prefetch_init__` method provides clean and robust initialization logic for the MoE weight prefetching feature:

- Safe defaults: Initializes all attributes to safe default values before conditional setup
- Configuration-driven: Only enables prefetching when `moe_prefetch_config` is present
- Proper object lifecycle: Creates manager and proxy objects with appropriate parameters
- Scalable design: Pre-allocates proxy list for all layers, then populates only needed indices
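The pre-allocate-then-populate pattern just described can be sketched as follows; `ProxySketch` and `build_proxies` are hypothetical names, not the PR's actual classes.

```python
class ProxySketch:
    """Placeholder for a per-layer prefetch proxy."""

    def __init__(self, layer_id):
        self.layer_id = layer_id


def build_proxies(num_layers, moe_layer_ids):
    # one slot per decoder layer; only MoE layers get a proxy,
    # dense layers keep None so indexing stays uniform
    proxies = [None] * num_layers
    for layer_id in moe_layer_ids:
        proxies[layer_id] = ProxySketch(layer_id)
    return proxies


proxies = build_proxies(num_layers=4, moe_layer_ids=[1, 3])
```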
The method integrates well with the existing model architecture and provides the foundation for MoE weight prefetching across all models that inherit from `DecoderModel`.

examples/llm-api/quickstart_advanced.py (1)
214-219: LGTM! The MoeConfig instantiation correctly includes the new MoE prefetch parameters.
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)
45-45: Well-structured prefetch proxy integration. The optional parameter and initialization logic properly integrate the prefetch proxy while maintaining backward compatibility.
Also applies to: 85-91
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
44-44: Clean integration of MoE prefetch configuration. The changes properly thread the `MoEPrefetchConfig` through the model loading pipeline, following the established pattern for configuration parameters.

Also applies to: 296-296, 933-933, 948-948
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (3)
9-9: Import statement looks good. The import of `MoEPrefetchProxy` is properly placed with other module imports and follows the relative import pattern.

60-72: Constructor changes are well-implemented. The optional `moe_prefetch_proxy` parameter maintains backward compatibility with a default `None` value and is properly propagated to the parent class.

282-288: No action needed for `use_prefetch` initialization. The `use_prefetch` attribute is defined and initialized in `tensorrt_llm/_torch/modules/fused_moe/interface.py` (lines 86 and 88). The conditional buffer-selection logic in `fused_moe_cutlass.py` is correct.

tensorrt_llm/llmapi/llm_args.py (2)
110-137: MoE prefetch configuration fields are well-designed. The new fields maintain backward compatibility with sensible defaults. The validator correctly ensures positive values for depth and stride parameters, which is essential for the prefetching logic.
2151-2151: Configuration propagation is correct. The `moe_prefetch_config` is properly passed from `moe_config` to `PyTorchConfig`, ensuring the prefetch settings reach the backend implementation.

tensorrt_llm/_torch/models/modeling_deepseekv3.py (4)
415-446: LGTM! Proper integration of MoE prefetch proxy. The optional `moe_prefetch_proxy` parameter is correctly added to the constructor and properly forwarded to the `create_moe` function, following the established pattern for MoE configuration.
585-643: LGTM! Consistent propagation of prefetch proxy. The `moe_prefetch_proxy` parameter is correctly propagated through the decoder layer to the MoE component when applicable (based on layer configuration).
1006-1030: LGTM! Proper initialization and distribution of prefetch proxies. The model correctly initializes the MoE prefetching infrastructure and distributes individual prefetch proxies to the appropriate decoder layers based on the MoE layer configuration.

1055-1058: LGTM! Correct prefetching trigger in forward pass. The forward method properly checks whether prefetching is enabled and triggers weight prefetching on the current CUDA stream before processing decoder layers.
tensorrt_llm/_torch/models/modeling_mixtral.py (4)
25-57: LGTM! Consistent MoE prefetch integration. The `moe_prefetch_proxy` parameter is properly integrated into MixtralMoE following the same pattern as other MoE models.
100-115: LGTM! Proper proxy propagation in decoder layer. The prefetch proxy is correctly passed through the decoder layer to the MoE component.

161-161: Verify MoE layer configuration for Mixtral. The `__moe_prefetch_init__` call doesn't pass MoE layer frequency parameters, unlike DeepseekV3 and Llama4. Please verify whether this is intentional based on Mixtral's architecture, where all layers might be MoE layers.
173-174: LGTM! Consistent prefetch implementation. The prefetch proxy distribution and forward pass prefetching logic are correctly implemented following the established pattern.
Also applies to: 201-204
tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py (3)
62-85: LGTM! Robust weight registration with proper validation. The method correctly validates weight consistency across MoE layers and allocates device buffers on first registration. The assertion messages provide clear diagnostics for shape/dtype mismatches.

101-113: LGTM! Efficient asynchronous prefetching implementation. The method correctly synchronizes CUDA streams and uses non-blocking copies to overlap data transfer with computation, which is essential for the prefetching performance benefits.

138-179: LGTM! Well-designed proxy pattern for per-layer prefetching. The MoEPrefetchProxy class provides a clean interface for individual layers to manage their prefetching logic while delegating buffer management to the central manager. The asynchronous prefetching implementation with proper stream synchronization is correct.
tensorrt_llm/_torch/models/modeling_llama.py (5)
258-290: LGTM! Consistent MoE prefetch integration. The prefetch proxy parameter is properly added and forwarded to the MoE creation function.

357-419: LGTM! Proper handling of mixed MLP/MoE layers. The implementation correctly passes the prefetch proxy only to MoE layers, not MLP layers, which aligns with Llama4's interleaved architecture.

622-626: LGTM! Proper MoE prefetch initialization with layer configuration. The initialization correctly passes the MoE layer frequency and offset flag, which is essential for determining which layers participate in prefetching.

657-663: LGTM! Clever handling of decoder layer variants. The conditional logic ensures the prefetch proxy is only passed to standard Llama4DecoderLayer instances, maintaining compatibility with the MinLatencyDecoderLayer variant.

690-693: LGTM! Consistent prefetch triggering. The forward method correctly triggers weight prefetching when enabled, following the same pattern as other models.
PR_Github #12612 [ run ] triggered by Bot

PR_Github #12612 [ run ] completed with state
/bot run --disable-fail-fast
Actionable comments posted: 0
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py (3)
12-61: Consider adding input validation and improving documentation. The constructor handles complex layer index calculations and prefetch buffer management. A few suggestions:
- Add input validation for parameters to prevent runtime errors
- The layer index calculation logic is complex and could benefit from inline comments
- Consider extracting the layer index calculation into a separate private method for better readability
```diff
 def __init__(self, num_hidden_layers: int, moe_layer_freq: int,
              add_one_to_layer_idx: bool, first_k_dense_replace: int,
              moe_prefetch_depth: int, moe_prefetch_stride: int):
+    # Input validation
+    if moe_prefetch_depth <= 0:
+        raise ValueError("moe_prefetch_depth must be positive")
+    if moe_prefetch_stride <= 0:
+        raise ValueError("moe_prefetch_stride must be positive")
+    if num_hidden_layers <= 0:
+        raise ValueError("num_hidden_layers must be positive")
+
     # tp and ep support only
     self.prefetch_depth = moe_prefetch_depth  # of buffers
```
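As one possible shape for the suggested extracted helper, here is a sketch of a layer-index calculation assuming DeepSeek-style semantics for `moe_layer_freq` and `first_k_dense_replace` (parameter names come from the constructor above, but the exact indexing rule is an assumption, not the PR's code):

```python
def moe_layer_indices(num_hidden_layers, moe_layer_freq,
                      first_k_dense_replace, add_one_to_layer_idx=False):
    """Return the decoder-layer indices that hold MoE blocks (sketch).

    Layers before `first_k_dense_replace` stay dense; afterwards every
    `moe_layer_freq`-th layer (optionally shifted by one) is an MoE layer.
    """
    offset = 1 if add_one_to_layer_idx else 0
    return [i for i in range(num_hidden_layers)
            if i >= first_k_dense_replace
            and (i + offset) % moe_layer_freq == 0]


# e.g. 8 layers, MoE every 2nd layer, first 2 layers dense
indices = moe_layer_indices(8, 2, 2)
```

Pulling this out of `__init__` would make the cadence testable in isolation, which is the maintainability point the review is making.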
62-84: Fix typo in assertion message. The method logic is sound with proper validation and lazy initialization. However, there's a typo in the assertion message.

```diff
-        assert len(weights) == 2, "Experted two weight tensors per moe layer"
+        assert len(weights) == 2, "Expected two weight tensors per moe layer"
```
80-80: Consider breaking long lines for better readability. Several lines exceed the 120 character limit. While not critical, consider breaking them for better code style compliance.

```diff
-        assert weights[0].shape == self.weight_shapes[
-            "w3_w1_weight"], f"MoE w3_w1 Weight shapes mismatch on layer {layer_id}: {self.weight_shapes['w3_w1_weight']} != {weights[0].shape}"
+        assert weights[0].shape == self.weight_shapes["w3_w1_weight"], (
+            f"MoE w3_w1 Weight shapes mismatch on layer {layer_id}: "
+            f"{self.weight_shapes['w3_w1_weight']} != {weights[0].shape}")
-        assert weights[1].shape == self.weight_shapes[
-            "w2_weight"], f"MoE w2 Weight shapes mismatch on layer {layer_id}: {self.weight_shapes['w2_weight']} != {weights[1].shape}"
+        assert weights[1].shape == self.weight_shapes["w2_weight"], (
+            f"MoE w2 Weight shapes mismatch on layer {layer_id}: "
+            f"{self.weight_shapes['w2_weight']} != {weights[1].shape}")
```

Also applies to: 82-82, 88-88, 129-129
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (16)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/model_config.py (2 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (8 hunks)
- tensorrt_llm/_torch/models/modeling_llama.py (8 hunks)
- tensorrt_llm/_torch/models/modeling_mixtral.py (7 hunks)
- tensorrt_llm/_torch/models/modeling_utils.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/__init__.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (7 hunks)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (3 hunks)
- tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/config.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (4 hunks)
- tensorrt_llm/llmapi/llm_args.py (3 hunks)
- tests/integration/defs/accuracy/test_llm_api_pytorch.py (1 hunks)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py
80-80: Line too long (148 > 120)
(E501)
82-82: Line too long (139 > 120)
(E501)
88-88: Line too long (127 > 120)
(E501)
129-129: Line too long (151 > 120)
(E501)
✅ Files skipped from review due to trivial changes (1)
- tensorrt_llm/_torch/model_config.py
🚧 Files skipped from review as they are similar to previous changes (14)
- tensorrt_llm/_torch/modules/fused_moe/__init__.py
- tensorrt_llm/_torch/pyexecutor/config.py
- tensorrt_llm/_torch/modules/fused_moe/create_moe.py
- tensorrt_llm/_torch/modules/fused_moe/interface.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/modules/fused_moe/quantization.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
- tensorrt_llm/_torch/models/modeling_utils.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
- tensorrt_llm/_torch/models/modeling_mixtral.py
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/models/modeling_llama.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py (7)
1-4: LGTM! Clean and appropriate imports. The imports are minimal and directly related to the functionality implemented in this module.

86-99: LGTM! Proper device buffer allocation. The method correctly allocates GPU memory buffers with the appropriate dtype and device placement.

101-112: LGTM! Proper CUDA stream management and asynchronous copying. The method correctly implements:

- Stream synchronization with `wait_stream`
- Non-blocking memory copies
- Proper tensor dtype handling with `.view()`

The asynchronous prefetching logic follows CUDA best practices.
114-135: LGTM! Well-implemented getter methods with proper validation. The getter methods provide clean interfaces with appropriate assertion checks to ensure layers are properly registered before access.

144-155: LGTM! Clean proxy design with proper initialization. The constructor properly initializes the proxy with necessary references and lazy-loads the prefetch stream. The design follows good separation of concerns.

157-165: LGTM! Proper weight registration with clean delegation. The method correctly registers weights with the prefetch manager while maintaining necessary local state for the proxy.

167-178: LGTM! Correct implementation of asynchronous prefetching. The method properly handles:
- Null checks for when no more prefetching is needed
- CUDA stream synchronization
- Non-blocking memory copies with correct dtype handling
The prefetching logic is well-implemented.
PR_Github #12730 [ run ] triggered by Bot

PR_Github #12730 [ run ] completed with state

/bot run --disable-fail-fast
Actionable comments posted: 7
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py (3)
77-82: Break long lines to improve readability. These lines exceed 120 characters. Consider breaking them for better readability.

```diff
-        assert self.weight_dtype == weights[
-            0].dtype, f"MoE Dtype mismatch on layer {layer_id}: {self.weight_dtype} != {weights[0].dtype}"
-        assert weights[0].shape == self.weight_shapes[
-            "w3_w1_weight"], f"MoE w3_w1 Weight shapes mismatch on layer {layer_id}: {self.weight_shapes['w3_w1_weight']} != {weights[0].shape}"
-        assert weights[1].shape == self.weight_shapes[
-            "w2_weight"], f"MoE w2 Weight shapes mismatch on layer {layer_id}: {self.weight_shapes['w2_weight']} != {weights[1].shape}"
+        assert self.weight_dtype == weights[0].dtype, (
+            f"MoE Dtype mismatch on layer {layer_id}: "
+            f"{self.weight_dtype} != {weights[0].dtype}"
+        )
+        assert weights[0].shape == self.weight_shapes["w3_w1_weight"], (
+            f"MoE w3_w1 Weight shapes mismatch on layer {layer_id}: "
+            f"{self.weight_shapes['w3_w1_weight']} != {weights[0].shape}"
+        )
+        assert weights[1].shape == self.weight_shapes["w2_weight"], (
+            f"MoE w2 Weight shapes mismatch on layer {layer_id}: "
+            f"{self.weight_shapes['w2_weight']} != {weights[1].shape}"
+        )
```
88-88: Break long assertion message for readability. The assertion message exceeds 120 characters.

```diff
-        assert self.weight_dtype is not None, "MoE Prefetched Weight dtype must be set before allocating device weight buffers"
+        assert self.weight_dtype is not None, (
+            "MoE Prefetched Weight dtype must be set before allocating device weight buffers"
+        )
```
179-179: Add newline at end of file. Python files should end with a newline character.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (16)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/model_config.py (2 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (8 hunks)
- tensorrt_llm/_torch/models/modeling_llama.py (8 hunks)
- tensorrt_llm/_torch/models/modeling_mixtral.py (7 hunks)
- tensorrt_llm/_torch/models/modeling_utils.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/__init__.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (7 hunks)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (3 hunks)
- tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/config.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (4 hunks)
- tensorrt_llm/llmapi/llm_args.py (3 hunks)
- tests/integration/defs/accuracy/test_llm_api_pytorch.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- tensorrt_llm/_torch/pyexecutor/config.py
🚧 Files skipped from review as they are similar to previous changes (14)
- tensorrt_llm/_torch/modules/fused_moe/create_moe.py
- tensorrt_llm/_torch/modules/fused_moe/__init__.py
- tensorrt_llm/_torch/model_config.py
- tensorrt_llm/_torch/models/modeling_utils.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/modules/fused_moe/interface.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tensorrt_llm/_torch/modules/fused_moe/quantization.py
- tensorrt_llm/_torch/models/modeling_mixtral.py
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
- tensorrt_llm/_torch/models/modeling_llama.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx,cu,py}: Use only spaces for indentation. Do not use tabs. Indent 4 spaces at a time.
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/moe_prefetch_manager.py
80-80: Line too long (148 > 120)
(E501)
82-82: Line too long (139 > 120)
(E501)
88-88: Line too long (127 > 120)
(E501)
129-129: Line too long (151 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
PR_Github #12756 [ run ] triggered by Bot

PR_Github #12756 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #12808 [ run ] triggered by Bot

PR_Github #12808 [ run ] completed with state

PR_Github #22851 [ run ] completed with state
Force-pushed f5bba19 to f25096f (Compare)
/bot run --disable-fail-fast

PR_Github #23088 [ run ] triggered by Bot. Commit:

PR_Github #23088 [ run ] completed with state

/bot run

PR_Github #23231 [ run ] triggered by Bot. Commit:

PR_Github #23231 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #23418 [ run ] triggered by Bot. Commit:

PR_Github #23418 [ run ] completed with state

/bot run

PR_Github #23428 [ run ] triggered by Bot. Commit:

PR_Github #23428 [ run ] completed with state

/bot run

PR_Github #23430 [ run ] triggered by Bot. Commit:

PR_Github #23430 [ run ] completed with state

/bot run
Signed-off-by: Xuanyu Chen <xuanyuc@nvidia.com>
/bot run --disable-fail-fast

PR_Github #23669 [ run ] triggered by Bot. Commit:

PR_Github #23669 [ run ] completed with state
Description
Overview
This PR introduces PyTorch runtime support for MoE model weight prefetching, primarily targeting:
Currently supported models include DeepSeek, LLaMa 4, and Mixtral, with minimal effort required to extend to other MoE models.
Background
Previously, the full set of MoE layer weights had to be stored in device memory, placing significant pressure on available memory — especially for models with hundreds of experts per layer. This PR introduces an alternative: keeping only a limited number of MoE layers in device memory, and dynamically prefetching weights from host memory on demand.
When performance is a critical concern, prefetching is recommended only during the prefill phase, where memory copy latency can be effectively overlapped with computation. This makes the approach well-suited for disaggregated serving scenarios.
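The copy/compute overlap described above can be sketched as a small scheduling simulation. All names and the depth/stride parameters below are illustrative assumptions, not the PR's actual API: only `depth` layer slots stay resident on-device, and computing layer i triggers the host-to-device copy for layer i + stride.

```python
# Toy simulation of depth/stride prefetch scheduling (hypothetical names,
# not the PR's API): only `depth` MoE layers are resident on-device at a
# time; while layer i computes, layer i + stride is copied host-to-device.

def prefetch_schedule(num_layers: int, stride: int = 1):
    """Pair each compute step with the layer whose copy it should enqueue."""
    return [(i, i + stride if i + stride < num_layers else None)
            for i in range(num_layers)]

def simulate(num_layers: int, depth: int, stride: int = 1) -> bool:
    """Check that every layer's weights are resident before its compute starts."""
    resident = list(range(min(depth, num_layers)))  # warm-up: preload `depth` layers
    for compute, prefetch in prefetch_schedule(num_layers, stride):
        if compute not in resident:
            return False  # the async copy did not finish in time
        if prefetch is not None and prefetch not in resident:
            victim = next((l for l in resident if l != compute), None)
            if victim is None:
                return False  # depth too small to overlap copy with compute
            resident.remove(victim)    # evict an idle slot back to host
            resident.append(prefetch)  # the async H2D copy targets this slot
    return True
```

Under this toy model, `simulate(8, 2)` succeeds (two resident slots suffice for stride 1), while `simulate(4, 1)` fails, since a single slot cannot hold the computing layer and receive the next copy at the same time.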
Implementation Details
Core Concepts
Prefetching Mechanism
Design Architecture
Prefetch scheduling is split between a model-level prefetch manager and per-layer prefetch proxies. This separation improves modularity and enables easy future extensions (e.g. ratio-based or priority-based prefetching strategies).
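A hedged sketch of that manager/proxy split follows. The class names mirror the PR's walkthrough, but the signatures and internals here are assumptions; a real implementation would enqueue `cudaMemcpyAsync` on a side stream rather than call a plain function.

```python
# Hypothetical sketch of the manager/proxy separation. The policy (which
# layers to fetch, how many) lives in the manager, so new strategies only
# need to touch that class; the mechanics of one copy live in the proxy.

class PrefetchProxy:
    """Owns one MoE layer's host-side weights and its device slot."""
    def __init__(self, layer_id, copy_fn):
        self.layer_id = layer_id
        self._copy_fn = copy_fn   # stands in for an async H2D copy
        self.on_device = False

    def start_prefetch(self):
        if not self.on_device:            # no-op once weights are resident
            self._copy_fn(self.layer_id)  # would enqueue cudaMemcpyAsync
            self.on_device = True

    def release(self):
        self.on_device = False            # device slot becomes reusable

class PrefetchManager:
    """Creates one proxy per MoE layer and decides what to fetch next."""
    def __init__(self, num_layers, depth, copy_fn):
        self.proxies = [PrefetchProxy(i, copy_fn) for i in range(num_layers)]
        self.depth = depth

    def prefetch_weights(self, start=0):
        # Keep the next `depth` layers' weights in flight.
        for p in self.proxies[start:start + self.depth]:
            p.start_prefetch()
```

For example, with `copies = []` and `PrefetchManager(4, 2, copies.append)`, a first `prefetch_weights()` enqueues copies for layers 0 and 1; a later `prefetch_weights(start=1)` only enqueues layer 2, because layer 1 is already resident.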
MoE Backend Support
The implementation is designed to be extensible with minimal integration overhead. For reference, Mixtral 8x7B is a simple and direct example.
Brief Performance Results
Test Coverage
The prefetching logic is integrated within the fused MoE module, and model-side integration only requires setting the appropriate prefetching config values. To avoid redundancy, we validate correctness using only the DeepSeek-V3-Lite integration test in the current pipeline.
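Per the PR walkthrough, the feature is enabled per run via a flag plus depth/stride settings. A hedged sketch of such an invocation follows: `--use_moe_prefetch` is named in the walkthrough, while the script path and the depth/stride flag spellings are guesses.

```shell
# Hypothetical invocation; only --use_moe_prefetch appears in the PR
# walkthrough, the other flag names are illustrative guesses.
python examples/run.py \
    --use_moe_prefetch \
    --moe_prefetch_depth 2 \
    --moe_prefetch_stride 1
```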
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ... provides a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
Details
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional) pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.
--disable-fail-fast (OPTIONAL): Disable fail-fast on build/test/infra failures.
--skip-test (OPTIONAL): Skip all test stages, but still run build, package, and sanity-check stages. Note: does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL): Skip test stages that don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to the L0 pre-merge pipeline.
--post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline plus the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
Kill all running builds associated with the pull request.
skip
skip --comment COMMENT
Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
Summary by CodeRabbit
New Features
Tests
Chores