
Commit 15f34f5

xxi-nvcodego7250 authored and committed
[TRTLLM-8959][feat] ConfigurableMoE support CUTLASS (NVIDIA#9772)
1 parent c18dd5e commit 15f34f5

File tree

9 files changed: +284 -209 lines changed


tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_two_sided.py

Lines changed: 5 additions & 1 deletion
@@ -65,6 +65,10 @@ def __init__(
             os.environ.get("TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1"
         )
 
+        # Invalid token expert ID (default -1); the kernels in TRTLLM-gen are hard-coded to support -1 only.
+        # CutlassFusedMoE kernels support any invalid value.
+        self.invalid_token_expert_id: int = -1
+
         # Initialize NVLINK workspaces
         MnnvlMemory.initialize()
         self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(mapping)
@@ -168,7 +172,7 @@ def dispatch(
             alltoall_info.recv_rank_count_cumsum,
             all_rank_max_num_tokens,
             top_k,
-            self.num_slots,
+            self.invalid_token_expert_id,
             self.ep_size,
         )
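The dispatch path now forwards a dedicated sentinel instead of self.num_slots. As a rough illustration of what that sentinel means downstream, here is a minimal PyTorch sketch (the helper name and shapes are hypothetical, not from this commit) of a combine step that drops slots marked invalid:

import torch

def combine_skip_invalid(expert_out: torch.Tensor,
                         token_expert_ids: torch.Tensor,
                         invalid_token_expert_id: int = -1) -> torch.Tensor:
    # expert_out:       [num_tokens, top_k, hidden] per-slot expert outputs
    # token_expert_ids: [num_tokens, top_k] expert IDs; padded slots carry the sentinel
    valid = (token_expert_ids != invalid_token_expert_id).unsqueeze(-1)
    # Zero out invalid slots before reducing over top_k; TRTLLM-gen kernels
    # hard-code -1 as this sentinel, while Cutlass kernels accept any value.
    return (expert_out * valid).sum(dim=1)

out = combine_skip_invalid(
    torch.randn(4, 2, 8),
    torch.tensor([[0, 1], [2, -1], [-1, -1], [3, 0]]),
)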

tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 95 additions & 33 deletions
@@ -402,6 +402,11 @@ def forward_impl(
         3. Execute MoE computation (single or multiple chunks)
         4. Handle output truncation and EPLB repeat
         """
+        # TODO: clarify whether the output_dtype is needed.
+        if isinstance(x, Fp4QuantizedTensor):
+            assert output_dtype is not None
+        else:
+            output_dtype = x.dtype
         # ========== Step 1: Handle padding ==========
         if all_rank_num_tokens is None:
             all_rank_num_tokens = [x.shape[0]]
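The added guard exists because an FP4-quantized input cannot reveal the desired output dtype: the quantized payload is bit-packed, so the tensor's storage dtype no longer matches the logical activation dtype. A minimal sketch of the reasoning (shapes illustrative; Fp4QuantizedTensor wraps such packed storage):

import torch

x_bf16 = torch.randn(8, 128, dtype=torch.bfloat16)
output_dtype = x_bf16.dtype                         # inferable: bfloat16

packed_fp4 = torch.empty(8, 64, dtype=torch.uint8)  # 128 fp4 values packed into 64 bytes
# packed_fp4.dtype is torch.uint8, not the logical activation dtype,
# so the caller must pass output_dtype explicitly -- which the assert enforces.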
@@ -662,7 +667,7 @@ def _forward_chunk_impl(
             token_final_scales=token_final_scales,
             x_sf=x_sf,
             **self._get_backend_kwargs(
-                router_logits, do_finalize, all_rank_num_tokens, output_dtype
+                router_logits, do_finalize, all_rank_num_tokens, output_dtype, x
             ),
         )

@@ -875,12 +880,68 @@ def _is_using_nvlink_two_sided(self) -> bool:
         """Check if using NVLinkTwoSided communication strategy"""
         return isinstance(self.comm, NVLinkTwoSided)
 
+    def _get_nvlink_onesided_moe_output(
+        self,
+        all_rank_num_tokens: Optional[List[int]],
+        output_dtype: Optional[torch.dtype],
+    ) -> Optional[torch.Tensor]:
+        """
+        Get workspace output buffer for the NVLinkOneSided communication backend.
+
+        This method handles moe_output allocation for both CutlassFusedMoE and TRTLLMGenFusedMoE
+        when using the NVLinkOneSided communication strategy.
+
+        Args:
+            all_rank_num_tokens: Token counts per rank
+            output_dtype: Output data type
+
+        Returns:
+            moe_output tensor if NVLinkOneSided is used and the backend supports it, None otherwise
+        """
+        if not isinstance(self.comm, NVLinkOneSided):
+            return None
+
+        # Determine workspace dtype and whether the backend supports workspace output
+        workspace_dtype = output_dtype
+        backend_supports_workspace = False
+
+        if isinstance(self.backend, TRTLLMGenFusedMoE):
+            # TRTLLMGen-specific configuration
+            self.comm.invalid_token_expert_id = -1
+            workspace_dtype = torch.bfloat16
+            backend_supports_workspace = self.backend.has_w4a8_mxfp4_mxfp8
+        elif isinstance(self.backend, CutlassFusedMoE):
+            # Cutlass always supports workspace output with NVLinkOneSided
+            backend_supports_workspace = True
+
+        if not backend_supports_workspace:
+            # Ensure payload_in_workspace is False if the backend doesn't support it
+            self.comm.payload_in_workspace = False
+            return None
+
+        # Calculate runtime max tokens per rank
+        assert all_rank_num_tokens is not None, (
+            "all_rank_num_tokens must be provided for NVLinkOneSided backend"
+        )
+        runtime_max_tokens_per_rank = max(all_rank_num_tokens)
+
+        # Get workspace-backed output tensor
+        moe_output = self.comm.get_combine_payload_tensor_in_workspace(
+            runtime_max_tokens_per_rank, self.hidden_size, workspace_dtype
+        )
+
+        # Dynamically enable payload_in_workspace for this forward pass
+        self.comm.payload_in_workspace = True
+
+        return moe_output
+
     def _get_backend_kwargs(
         self,
         router_logits: Optional[torch.Tensor] = None,
         do_finalize: bool = True,
         all_rank_num_tokens: Optional[List[int]] = None,
         output_dtype: Optional[torch.dtype] = None,
+        x: Optional[torch.Tensor] = None,
     ) -> Dict:
         """
         Get backend-specific keyword arguments for run_moe
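Summarized from the helper above, its branching reduces to a small decision table:

    comm                 MoE backend          moe_output                     workspace dtype
    -------------------  -------------------  -----------------------------  ---------------
    not NVLinkOneSided   any                  None                           -
    NVLinkOneSided       TRTLLMGenFusedMoE    only if has_w4a8_mxfp4_mxfp8   torch.bfloat16
    NVLinkOneSided       CutlassFusedMoE      always                         output_dtype

In the unsupported cases the helper also forces comm.payload_in_workspace back to False before returning None.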
@@ -905,6 +966,8 @@ def _get_backend_kwargs(
             router_logits: Router logits tensor (for TRTLLMGen backend)
             do_finalize: Whether to finalize output (for TRTLLMGen backend)
             all_rank_num_tokens: Token counts per rank (for TRTLLMGen backend moe_output)
+            output_dtype: Output data type
+            x: Input tensor (for calculating tuner_num_tokens in Cutlass)
 
         Returns:
             Dict: Backend-specific keyword arguments
@@ -917,7 +980,33 @@ def _get_backend_kwargs(
 
         # Cutlass-specific parameters
         if self.backend.__class__ == CutlassFusedMoE:
-            pass
+            # Determine if scaling factors are swizzled based on communication flow
+            # In post-quant communication (quantize -> dispatch), scaling factors are not swizzled
+            # In pre-quant communication (dispatch -> quantize), scaling factors are swizzled
+            supports_post_quant = self.comm is not None and self.comm.supports_post_quant_dispatch()
+            kwargs["is_sf_swizzled"] = not supports_post_quant
+            kwargs["output_dtype"] = output_dtype
+
+            # Prepare additional information for profiling in case padding is applied when using alltoall.
+            # Only the non-alltoall case is considered for profiling in the warmup phase.
+            # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner
+            # should be the same as when not using alltoall.
+            if self._is_using_alltoall():
+                if all_rank_num_tokens is not None:
+                    kwargs["tuner_num_tokens"] = sum(all_rank_num_tokens)
+                else:
+                    kwargs["tuner_num_tokens"] = (
+                        x.shape[0] * self.mapping.tp_size if x is not None else None
+                    )
+                kwargs["tuner_top_k"] = self.routing_method.top_k
+            else:
+                kwargs["tuner_num_tokens"] = None
+                kwargs["tuner_top_k"] = None
+
+            # Get moe_output for NVLinkOneSided backend
+            kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
+                all_rank_num_tokens, output_dtype
+            )
 
         # CuteDSL-specific parameters
         elif self.backend.__class__ == CuteDslFusedMoE:
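The tuner bookkeeping above can be followed with a small standalone sketch (numbers hypothetical): during warmup only the non-alltoall path is profiled, where a rank sees the whole batch, so at inference the tuner must be handed the global token count rather than the smaller post-dispatch local one.

# Hypothetical per-rank token counts with alltoall enabled
all_rank_num_tokens = [3, 5, 4, 4]
tp_size = 4
local_tokens = 5                     # x.shape[0] on this rank

tuner_num_tokens = sum(all_rank_num_tokens)   # 16: matches the profiled, non-alltoall input
fallback_estimate = local_tokens * tp_size    # 20: upper-bound estimate when counts are unknown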
@@ -940,37 +1029,10 @@ def _get_backend_kwargs(
             kwargs["router_logits"] = router_logits_arg
             kwargs["do_finalize"] = do_finalize
 
-            # moe_output: workspace output buffer for NVLINK one-sided backend
-            # TRTLLMGenFusedMoE only supports workspace output for w4a8_mxfp4_mxfp8 quantization.
-            moe_output = None
-            if isinstance(self.comm, NVLinkOneSided):
-                # Determine dtype for workspace tensor
-                # TRTLLMGenFusedMoE always uses bfloat16, other backends use output_dtype
-                workspace_dtype = output_dtype
-                if isinstance(self.backend, TRTLLMGenFusedMoE):
-                    self.comm.invalid_token_expert_id = -1
-                    workspace_dtype = torch.bfloat16
-
-                # Check if backend supports workspace output for current quantization
-                backend_supports_workspace = (
-                    isinstance(self.backend, TRTLLMGenFusedMoE)
-                    and self.backend.has_w4a8_mxfp4_mxfp8
-                )
-                if backend_supports_workspace:
-                    assert all_rank_num_tokens is not None, (
-                        "all_rank_num_tokens must be provided for NVLinkOneSided backend with workspace output"
-                    )
-                    runtime_max_tokens_per_rank = max(all_rank_num_tokens)
-
-                    moe_output = self.comm.get_combine_payload_tensor_in_workspace(
-                        runtime_max_tokens_per_rank, self.hidden_size, workspace_dtype
-                    )
-                    # Dynamically enable payload_in_workspace for this forward pass
-                    self.comm.payload_in_workspace = True
-                else:
-                    # Ensure payload_in_workspace is False for non-workspace output
-                    self.comm.payload_in_workspace = False
-            kwargs["moe_output"] = moe_output
+            # Get moe_output for NVLinkOneSided backend
+            kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
+                all_rank_num_tokens, output_dtype
+            )
 
         return kwargs
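Both backend branches now obtain moe_output from the shared _get_nvlink_onesided_moe_output helper defined earlier in this commit; the inline TRTLLMGen-only allocation deleted above is the logic it replaces.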

tensorrt_llm/_torch/modules/fused_moe/create_moe.py

Lines changed: 1 addition & 1 deletion
@@ -346,7 +346,7 @@ def create_moe(
 
     if ENABLE_CONFIGURABLE_MOE or moe_cls == CuteDslFusedMoE:
         # ConfigurableMoE only supports TRTLLMGenFusedMoE and CuteDslFusedMoE backends
-        if moe_cls in (TRTLLMGenFusedMoE, CuteDslFusedMoE):
+        if moe_cls in (TRTLLMGenFusedMoE, CuteDslFusedMoE, CutlassFusedMoE):
            return ConfigurableMoE(
                routing_method=routing_method,
                num_experts=num_experts,
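A minimal sketch of the gate after this change (stand-in classes; the real create_moe takes the full module configuration): CutlassFusedMoE now routes through ConfigurableMoE as well, while the outer condition is unchanged.

class TRTLLMGenFusedMoE: ...
class CuteDslFusedMoE: ...
class CutlassFusedMoE: ...

ENABLE_CONFIGURABLE_MOE = True

def uses_configurable_moe(moe_cls) -> bool:
    # Mirrors the condition in create_moe: the outer gate is unchanged,
    # only the accepted backend list grows.
    return (ENABLE_CONFIGURABLE_MOE or moe_cls is CuteDslFusedMoE) and moe_cls in (
        TRTLLMGenFusedMoE,
        CuteDslFusedMoE,
        CutlassFusedMoE,
    )

assert uses_configurable_moe(CutlassFusedMoE)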
