Skip to content

Commit 24b05e8

Browse files
xxi-nv (XingFei Xi)
authored and committed
[TRTLLM-8958] and [TRTLLM-8960]: create ConfigurableMoE and support the TRTLLMGenFusedMoE as the backend in ConfigurableMoE
Signed-off-by: xxi <xxi@nvidia.com> modified: tensorrt_llm/_torch/model_config.py modified: tensorrt_llm/_torch/models/modeling_deepseekv3.py modified: tensorrt_llm/_torch/models/modeling_gpt_oss.py modified: tensorrt_llm/_torch/models/modeling_hunyuan_moe.py modified: tensorrt_llm/_torch/models/modeling_utils.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/base.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py modified: tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py renamed: tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_throughput.py -> tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py renamed: tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_latency.py -> tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_two_sided.py new file: tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py modified: tensorrt_llm/_torch/modules/fused_moe/create_moe.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py modified: tensorrt_llm/_torch/modules/fused_moe/interface.py modified: tests/unittest/_torch/modules/test_fused_moe.py
1 parent e484bec commit 24b05e8

21 files changed

+2063
-495
lines changed

tensorrt_llm/_torch/model_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
165165
self.allreduce_strategy = get_all_reduce_strategy(
166166
self.allreduce_strategy)
167167

168+
# Set default moe_max_num_tokens if not specified
169+
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
170+
if self.moe_max_num_tokens is None:
171+
self.moe_max_num_tokens = self.max_num_tokens * self.mapping.dp_size
172+
168173
@property
169174
def torch_dtype(self) -> torch.dtype:
170175
"""Get the torch dtype of the model."""

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from ..modules.attention import MLA
5656
from ..modules.decoder_layer import DecoderLayer
5757
from ..modules.embedding import Embedding
58-
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
58+
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, MoE,
5959
MoEWeightLoadingMode, create_moe)
6060
from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
6161
from ..modules.gated_mlp import GatedMLP
@@ -382,6 +382,21 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
382382
"gate_proj": "w1",
383383
})
384384
module.load_weights(weights=[module_weights])
385+
elif names[-1] == "backend" and isinstance(module, MoE):
386+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
387+
# Currently saved MoE weights don't include 'backend' in their names.
388+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
389+
# and weights loading is done in the backend, so module name includes '.backend'.
390+
# We need to use parent module name (without .backend) to match saved weight names.
391+
# After MoE refactoring is fully complete, all paths will follow this branch.
392+
parent_name = '.'.join(names[:-1])
393+
module_weights = filter_weights(parent_name, weights)
394+
module_weights = rename_moe_weight(module_weights, {
395+
"down_proj": "w2",
396+
"up_proj": "w3",
397+
"gate_proj": "w1",
398+
})
399+
module.load_weights(weights=[module_weights])
385400
elif names[-1] == "self_attn":
386401
continue
387402
elif names[-1] == "next_layer_layernorm":

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,18 @@ def load_hf_weights(self, weights: Dict):
657657
module_weights = {}
658658
for k, v in self.hf_params_map.items():
659659
name = name.replace(k, v)
660+
661+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
662+
# Currently saved MoE weights don't include 'backend' in their names.
663+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
664+
# and weights loading is done in the backend, so module name includes '.backend'.
665+
# We need to use parent module name (without .backend) to match saved weight names.
666+
# After MoE refactoring is fully complete, all paths will follow this branch.
667+
names = name.split('.')
668+
if names[-1] == "backend" and isinstance(module, MoE):
669+
# Backend is under experts module (ConfigurableMoE wrapper)
670+
name = '.'.join(names[:-1])
671+
660672
module_weights = filter_weights(name, weights)
661673

662674
if isinstance(module, MoE):

tensorrt_llm/_torch/models/modeling_hunyuan_moe.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
from ..modules.attention import Attention
1616
from ..modules.decoder_layer import DecoderLayer
1717
from ..modules.embedding import Embedding
18-
from ..modules.fused_moe import (CutlassFusedMoE, RenormalizeMoeRoutingMethod,
19-
VanillaMoE, create_moe)
18+
from ..modules.fused_moe import (CutlassFusedMoE, MoE,
19+
RenormalizeMoeRoutingMethod, VanillaMoE,
20+
create_moe)
2021
from ..modules.gated_mlp import GatedMLP
2122
from ..modules.linear import Linear, TensorParallelMode
2223
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -364,6 +365,17 @@ def filter_weights(prefix, weights: Dict):
364365
"lm_head"):
365366
continue
366367
names = name.split('.')
368+
369+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
370+
# Currently saved MoE weights don't include 'backend' in their names.
371+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
372+
# and weights loading is done in the backend, so module name includes '.backend'.
373+
# We need to use parent module name (without .backend) to match saved weight names.
374+
# After MoE refactoring is fully complete, all paths will follow this branch.
375+
if names[-1] == "backend" and isinstance(module, MoE):
376+
name = '.'.join(names[:-1])
377+
names = name.split('.')
378+
367379
if names[-1] in params_map:
368380
# model.layers.{idx}.mlp.shared_mlp.gate_up_proj or model.layers.{idx}.self_attn.qkv_proj
369381
module_weights = []

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,17 @@ def load_single_module(name, module):
863863
return
864864

865865
names = name.split('.')
866+
867+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
868+
# Currently saved MoE weights don't include 'backend' in their names.
869+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
870+
# and weights loading is done in the backend, so module name includes '.backend'.
871+
# We need to use parent module name (without .backend) to match saved weight names.
872+
# After MoE refactoring is fully complete, all paths will follow this branch.
873+
if names[-1] == "backend" and isinstance(module, MoE):
874+
name = '.'.join(names[:-1])
875+
names = name.split('.')
876+
866877
# WAR: better solution is that llama has its own load_weights function.
867878
if names[-1] == 'next_layer_layernorm':
868879
return
@@ -956,6 +967,17 @@ def load_single_module(name, module):
956967
return
957968

958969
names = name.split('.')
970+
971+
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
972+
# Currently saved MoE weights don't include 'backend' in their names.
973+
# After MoE refactoring, ConfigurableMoE now has a backend submodule,
974+
# and weights loading is done in the backend, so module name includes '.backend'.
975+
# We need to use parent module name (without .backend) to match saved weight names.
976+
# After MoE refactoring is fully complete, all paths will follow this branch.
977+
if names[-1] == "backend" and isinstance(module, MoE):
978+
name = '.'.join(names[:-1])
979+
names = name.split('.')
980+
959981
module_names_breakdown, module_name = names[:-1], names[-1]
960982

961983
if weight_mapper.does_require_special_handling(module_name):

tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
2121
Available Communication Methods:
2222
- AllGatherReduceScatter: Default fallback method, always available
23-
- MnnvlLatency: MNNVL-optimized communication for latency
24-
- MNNVLThroughput: MNNVL-optimized communication for throughput
23+
- NVLinkTwoSided: NVLINK-optimized communication for latency (formerly MNNVLLatency)
24+
- NVLinkOneSided: NVLINK-optimized communication for throughput (formerly MNNVLThroughput)
2525
- DeepEP: Deep Expert Parallelism with support for large batches
2626
- DeepEPLowLatency: Deep Expert Parallelism optimized for low latency
2727
@@ -34,16 +34,16 @@
3434
from .communication_factory import CommunicationFactory
3535
from .deep_ep import DeepEP
3636
from .deep_ep_low_latency import DeepEPLowLatency
37-
from .mnnvl_latency import MnnvlLatency
38-
from .mnnvl_throughput import MNNVLThroughput
37+
from .nvlink_one_sided import NVLinkOneSided
38+
from .nvlink_two_sided import NVLinkTwoSided
3939

4040
__all__ = [
4141
# Base classes and types
4242
"Communication",
4343
# Communication strategies
4444
"AllGatherReduceScatter",
45-
"MnnvlLatency",
46-
"MNNVLThroughput",
45+
"NVLinkTwoSided",
46+
"NVLinkOneSided",
4747
"DeepEP",
4848
"DeepEPLowLatency",
4949
# Factory

tensorrt_llm/_torch/modules/fused_moe/communication/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
self.mapping = mapping
5050
self.ep_size = mapping.moe_ep_size
5151
self.ep_rank = mapping.moe_ep_rank
52+
self._is_platform_supported = False
5253

5354
@abstractmethod
5455
def is_workload_feasible(

tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
from .base import Communication
3333
from .deep_ep import DeepEP
3434
from .deep_ep_low_latency import DeepEPLowLatency
35-
from .mnnvl_latency import MnnvlLatency
36-
from .mnnvl_throughput import MNNVLThroughput
35+
from .nvlink_one_sided import NVLinkOneSided
36+
from .nvlink_two_sided import NVLinkTwoSided
3737

3838

3939
def is_high_throughput() -> bool:
@@ -72,7 +72,7 @@ class CommunicationFactory:
7272
Factory for creating MoE communication methods
7373
7474
Selects the best communication method based on:
75-
- Hardware support (MNNVL, DeepEP)
75+
- Hardware support (NVLINK, DeepEP)
7676
- Configuration settings
7777
- Workload characteristics
7878
"""
@@ -85,16 +85,17 @@ def create_strategy(
8585
top_k: int,
8686
expert_size_per_partition: int,
8787
payload_in_workspace: bool = False,
88-
alltoall_result_do_sum: bool = False,
88+
alltoall_result_do_sum: bool = True,
8989
) -> Optional[Communication]:
9090
"""
9191
Create the best communication method for the given configuration
9292
9393
Selection priority:
94-
1. Force method (if specified via TRTLLM_FORCE_ALLTOALL_METHOD env)
95-
2. MNNVL (if hardware supports)
94+
1. Force method (if specified via TRTLLM_FORCE_COMM_METHOD env)
95+
2. NVLINK (if hardware supports)
9696
- Selects latency or throughput backend based on TRTLLM_MOE_ALLTOALL_BACKEND env
97-
- Default: "mnnvllatency", alternative: "mnnvlthroughput"
97+
- Default: "NVLinkTwoSided", legacy: "mnnvllatency"
98+
- Alternative: "NVLinkOneSided", legacy: "mnnvlthroughput"
9899
3. DeepEP / DeepEPLowLatency (if enabled and hardware supports)
99100
4. AllGather + ReduceScatter (fallback, always works)
100101
@@ -104,8 +105,8 @@ def create_strategy(
104105
num_slots: Total number of expert slots
105106
top_k: Number of experts per token
106107
expert_size_per_partition: Number of experts per partition (required for DeepEP)
107-
payload_in_workspace: If True, final_hidden_states is already in workspace (for MNNVLThroughput)
108-
alltoall_result_do_sum: If True, sum the alltoall results (for MnnvlLatency)
108+
payload_in_workspace: If True, final_hidden_states is already in workspace (for NVLinkOneSided)
109+
alltoall_result_do_sum: If True, sum the alltoall results (for NVLinkTwoSided)
109110
110111
Returns:
111112
The selected communication method, or None if attention does not use DP
@@ -134,19 +135,19 @@ def create_strategy(
134135
return AllGatherReduceScatter(mapping)
135136

136137
# Check if forced method is specified via environment variable
137-
force_method = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
138+
force_method = os.environ.get("TRTLLM_FORCE_COMM_METHOD")
138139

139140
if force_method is not None:
140141
# Validate platform support for forced method
141142
method_upper = force_method.upper()
142-
if method_upper in ["MNNVLLATENCY", "MNNVLTHROUGHPUT"]:
143-
if not MnnvlLatency.is_platform_supported():
143+
if method_upper in ["NVLINK_TWO_SIDED", "NVLINK_ONE_SIDED"]:
144+
if not NVLinkTwoSided.is_platform_supported():
144145
raise RuntimeError(
145146
f"Forced method '{force_method}' is not supported on this platform. "
146-
"MNNVLLATENCY and MNNVLTHROUGHPUT require compatible hardware."
147+
"NVLINK two-sided and one-sided modes require compatible hardware."
147148
)
148149
elif method_upper in ["DEEPEP", "DEEPEPLOWLATENCY"]:
149-
if not DeepEP.is_platform_supported(mapping):
150+
if not DeepEP.is_platform_supported():
150151
raise RuntimeError(
151152
f"Forced method '{force_method}' is not supported on this platform. "
152153
"DeepEP requires compatible hardware and TRTLLM_CAN_USE_DEEP_EP=1."
@@ -163,19 +164,20 @@ def create_strategy(
163164
alltoall_result_do_sum,
164165
)
165166

166-
# Try MNNVL first (highest priority)
167-
if MnnvlLatency.is_platform_supported():
167+
# Try NVLINK first (highest priority)
168+
if NVLinkTwoSided.is_platform_supported():
169+
# TODO: update when we have a more clear heuristic.
168170
if is_high_throughput():
169-
# Currently, MNNVLThroughput shows better performance at all scenarios
170-
return MNNVLThroughput(
171+
# Currently, NVLinkOneSided shows better performance at all scenarios
172+
return NVLinkOneSided(
171173
mapping,
172-
num_experts,
174+
num_slots,
173175
top_k,
174176
max_num_tokens_per_rank=max_num_tokens,
175177
payload_in_workspace=payload_in_workspace,
176178
)
177179
else:
178-
return MnnvlLatency(
180+
return NVLinkTwoSided(
179181
mapping,
180182
num_experts,
181183
num_slots,
@@ -187,9 +189,7 @@ def create_strategy(
187189
# Try DeepEP
188190
if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
189191
if weight_dtype == torch.bfloat16:
190-
if DeepEP.is_platform_supported(mapping) and is_deepep_feasible(
191-
mapping.moe_ep_size
192-
):
192+
if DeepEP.is_platform_supported() and is_deepep_feasible(mapping.moe_ep_size):
193193
return DeepEP(
194194
mapping,
195195
num_slots,
@@ -240,21 +240,21 @@ def _create_forced_method(
240240

241241
method = method.upper()
242242

243-
if method == "MNNVLLATENCY":
244-
return MnnvlLatency(
243+
if method in ["NVLINK_TWO_SIDED"]:
244+
return NVLinkTwoSided(
245245
mapping,
246246
num_experts,
247247
num_slots,
248248
top_k,
249249
use_low_precision_combine,
250250
alltoall_result_do_sum=alltoall_result_do_sum,
251251
)
252-
elif method == "MNNVLTHROUGHPUT":
253-
# MNNVLThroughput requires max_num_tokens_per_rank
252+
elif method in ["NVLINK_ONE_SIDED"]:
253+
# NVLinkOneSided requires max_num_tokens_per_rank
254254
# max_num_tokens is per-rank value (as passed from callers like cutlass)
255-
return MNNVLThroughput(
255+
return NVLinkOneSided(
256256
mapping,
257-
num_experts,
257+
num_slots,
258258
top_k,
259259
max_num_tokens_per_rank=max_num_tokens,
260260
payload_in_workspace=payload_in_workspace,

tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,11 @@ def __init__(
6666
self.deep_ep_buffer = buffer_pool.get_buffer(mapping)
6767
self.deep_ep_buffer.reserve(hidden_size, weight_dtype)
6868

69+
# Initialize platform support check result
70+
self._is_platform_supported = self.is_platform_supported()
71+
6972
@staticmethod
70-
def is_platform_supported(mapping: Mapping) -> bool:
73+
def is_platform_supported() -> bool:
7174
"""
7275
Check if DeepEP is supported on the current platform
7376
"""
@@ -94,7 +97,7 @@ def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int)
9497
return False
9598
if self.weight_dtype != torch.bfloat16:
9699
return False
97-
return self.is_platform_supported(self.mapping)
100+
return self._is_platform_supported
98101

99102
def dispatch(
100103
self,

tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,11 @@ def __init__(
7979
self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(mapping)
8080
self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens, hidden_size, num_slots)
8181

82+
# Initialize platform support check result
83+
self._is_platform_supported = self.is_platform_supported()
84+
8285
@staticmethod
83-
def is_platform_supported(mapping: Mapping) -> bool:
86+
def is_platform_supported() -> bool:
8487
"""
8588
Check if DeepEP Low Latency is supported on the current platform
8689
"""
@@ -113,7 +116,7 @@ def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int)
113116
return False
114117
if self.weight_dtype != torch.bfloat16:
115118
return False
116-
return self.is_platform_supported(self.mapping)
119+
return self._is_platform_supported
117120

118121
def dispatch(
119122
self,

0 commit comments

Comments (0)