Skip to content

Commit ba75b99

Browse files
committed
[TRTLLM-8958] and [TRTLLM-8960]: create ConfigurableMoE and support the TRTLLMGenFusedMoE as the backend in ConfigurableMoE
Signed-off-by: xxi <xxi@nvidia.com> modified: tensorrt_llm/_torch/model_config.py modified: tensorrt_llm/_torch/models/modeling_deepseekv3.py modified: tensorrt_llm/_torch/models/modeling_gpt_oss.py modified: tensorrt_llm/_torch/models/modeling_hunyuan_moe.py modified: tensorrt_llm/_torch/models/modeling_utils.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/allgather_reducescatter.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/base.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py renamed: tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_throughput.py -> tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py renamed: tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_latency.py -> tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_two_sided.py new file: tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py modified: tensorrt_llm/_torch/modules/fused_moe/create_moe.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py modified: tensorrt_llm/_torch/modules/fused_moe/interface.py modified: tests/unittest/_torch/modeling/test_modeling_nemotron_h.py modified: tests/unittest/_torch/modules/test_fused_moe.py
1 parent 2f8bd6f commit ba75b99

23 files changed

+2210
-598
lines changed

tensorrt_llm/_torch/model_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
165165
self.allreduce_strategy = get_all_reduce_strategy(
166166
self.allreduce_strategy)
167167

168+
# Set default moe_max_num_tokens if not specified
169+
# The maximum number of tokens in MoE is multiplied by the DP size when attention DP is enabled
170+
if self.moe_max_num_tokens is None:
171+
self.moe_max_num_tokens = self.max_num_tokens * self.mapping.dp_size
172+
168173
@property
169174
def torch_dtype(self) -> torch.dtype:
170175
"""Get the torch dtype of the model."""

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from ..modules.attention import MLA
5656
from ..modules.decoder_layer import DecoderLayer
5757
from ..modules.embedding import Embedding
58-
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
58+
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, MoE,
5959
MoEWeightLoadingMode, create_moe)
6060
from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
6161
from ..modules.gated_mlp import GatedMLP
@@ -382,6 +382,21 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
382382
"gate_proj": "w1",
383383
})
384384
module.load_weights(weights=[module_weights])
385+
elif names[-1] == "backend" and isinstance(module, MoE):
386+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
387+
# Currently saved MoE weights don't include 'backend' in their names.
388+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
389+
# and weights loading is done in the backend, so module name includes '.backend'.
390+
# We need to use parent module name (without .backend) to match saved weight names.
391+
# After MoE refactoring is fully complete, all paths will follow this branch.
392+
parent_name = '.'.join(names[:-1])
393+
module_weights = filter_weights(parent_name, weights)
394+
module_weights = rename_moe_weight(module_weights, {
395+
"down_proj": "w2",
396+
"up_proj": "w3",
397+
"gate_proj": "w1",
398+
})
399+
module.load_weights(weights=[module_weights])
385400
elif names[-1] == "self_attn":
386401
continue
387402
elif names[-1] == "next_layer_layernorm":

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,18 @@ def load_hf_weights(self, weights: Dict):
657657
module_weights = {}
658658
for k, v in self.hf_params_map.items():
659659
name = name.replace(k, v)
660+
661+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
662+
# Currently saved MoE weights don't include 'backend' in their names.
663+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
664+
# and weights loading is done in the backend, so module name includes '.backend'.
665+
# We need to use parent module name (without .backend) to match saved weight names.
666+
# After MoE refactoring is fully complete, all paths will follow this branch.
667+
names = name.split('.')
668+
if names[-1] == "backend" and isinstance(module, MoE):
669+
# Backend is under experts module (ConfigurableMoE wrapper)
670+
name = '.'.join(names[:-1])
671+
660672
module_weights = filter_weights(name, weights)
661673

662674
if isinstance(module, MoE):

tensorrt_llm/_torch/models/modeling_hunyuan_moe.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
from ..modules.attention import Attention
1616
from ..modules.decoder_layer import DecoderLayer
1717
from ..modules.embedding import Embedding
18-
from ..modules.fused_moe import (CutlassFusedMoE, RenormalizeMoeRoutingMethod,
19-
VanillaMoE, create_moe)
18+
from ..modules.fused_moe import (CutlassFusedMoE, MoE,
19+
RenormalizeMoeRoutingMethod, VanillaMoE,
20+
create_moe)
2021
from ..modules.gated_mlp import GatedMLP
2122
from ..modules.linear import Linear, TensorParallelMode
2223
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -364,6 +365,17 @@ def filter_weights(prefix, weights: Dict):
364365
"lm_head"):
365366
continue
366367
names = name.split('.')
368+
369+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
370+
# Currently saved MoE weights don't include 'backend' in their names.
371+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
372+
# and weights loading is done in the backend, so module name includes '.backend'.
373+
# We need to use parent module name (without .backend) to match saved weight names.
374+
# After MoE refactoring is fully complete, all paths will follow this branch.
375+
if names[-1] == "backend" and isinstance(module, MoE):
376+
name = '.'.join(names[:-1])
377+
names = name.split('.')
378+
367379
if names[-1] in params_map:
368380
# model.layers.{idx}.mlp.shared_mlp.gate_up_proj or model.layers.{idx}.self_attn.qkv_proj
369381
module_weights = []

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,17 @@ def load_single_module(name, module):
868868
return
869869

870870
names = name.split('.')
871+
872+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
873+
# Currently saved MoE weights don't include 'backend' in their names.
874+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
875+
# and weights loading is done in the backend, so module name includes '.backend'.
876+
# We need to use parent module name (without .backend) to match saved weight names.
877+
# After MoE refactoring is fully complete, all paths will follow this branch.
878+
if names[-1] == "backend" and isinstance(module, MoE):
879+
name = '.'.join(names[:-1])
880+
names = name.split('.')
881+
871882
# WAR: better solution is that llama has its own load_weights function.
872883
if names[-1] == 'next_layer_layernorm':
873884
return
@@ -968,6 +979,17 @@ def load_single_module(name, module):
968979
return
969980

970981
names = name.split('.')
982+
983+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
984+
# Currently saved MoE weights don't include 'backend' in their names.
985+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
986+
# and weights loading is done in the backend, so module name includes '.backend'.
987+
# We need to use parent module name (without .backend) to match saved weight names.
988+
# After MoE refactoring is fully complete, all paths will follow this branch.
989+
if names[-1] == "backend" and isinstance(module, MoE):
990+
name = '.'.join(names[:-1])
991+
names = name.split('.')
992+
971993
module_names_breakdown, module_name = names[:-1], names[-1]
972994

973995
if weight_mapper.does_require_special_handling(module_name):

tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
2121
Available Communication Methods:
2222
- AllGatherReduceScatter: Default fallback method, always available
23-
- MnnvlLatency: MNNVL-optimized communication for latency
24-
- MNNVLThroughput: MNNVL-optimized communication for throughput
23+
- NVLinkTwoSided: NVLINK-optimized communication for latency (formerly MnnvlLatency)
24+
- NVLinkOneSided: NVLINK-optimized communication for throughput (formerly MNNVLThroughput)
2525
- DeepEP: Deep Expert Parallelism with support for large batches
2626
- DeepEPLowLatency: Deep Expert Parallelism optimized for low latency
2727
@@ -34,16 +34,16 @@
3434
from .communication_factory import CommunicationFactory
3535
from .deep_ep import DeepEP
3636
from .deep_ep_low_latency import DeepEPLowLatency
37-
from .mnnvl_latency import MnnvlLatency
38-
from .mnnvl_throughput import MNNVLThroughput
37+
from .nvlink_one_sided import NVLinkOneSided
38+
from .nvlink_two_sided import NVLinkTwoSided
3939

4040
__all__ = [
4141
# Base classes and types
4242
"Communication",
4343
# Communication strategies
4444
"AllGatherReduceScatter",
45-
"MnnvlLatency",
46-
"MNNVLThroughput",
45+
"NVLinkTwoSided",
46+
"NVLinkOneSided",
4747
"DeepEP",
4848
"DeepEPLowLatency",
4949
# Factory

tensorrt_llm/_torch/modules/fused_moe/communication/allgather_reducescatter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def __init__(
4242
# Initialize dispatch state
4343
self._dispatch_state = {}
4444

45+
@staticmethod
46+
def is_platform_supported() -> bool:
47+
"""
48+
AllGather + ReduceScatter is always supported as the fallback strategy
49+
"""
50+
return True
51+
4552
def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool:
4653
"""
4754
Check if AllGather is feasible for the given workload at runtime.

tensorrt_llm/_torch/modules/fused_moe/communication/base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,31 @@ def __init__(
5050
self.ep_size = mapping.moe_ep_size
5151
self.ep_rank = mapping.moe_ep_rank
5252

53+
# Check platform support and raise error if not supported
54+
if not self.is_platform_supported():
55+
raise RuntimeError(
56+
f"Communication strategy {self.__class__.__name__} "
57+
f"is not supported on this platform."
58+
)
59+
self._is_platform_supported = True
60+
61+
@staticmethod
62+
@abstractmethod
63+
def is_platform_supported() -> bool:
64+
"""
65+
Check if this communication strategy is supported on the current platform.
66+
67+
This method performs platform/hardware checks to determine if the strategy
68+
can be used on the current system.
69+
70+
Returns:
71+
True if platform is supported, False otherwise
72+
73+
Note: This is a static method that can be called before instantiation
74+
to check compatibility without creating an instance.
75+
"""
76+
raise NotImplementedError
77+
5378
@abstractmethod
5479
def is_workload_feasible(
5580
self,

0 commit comments

Comments
 (0)