@@ -402,6 +402,11 @@ def forward_impl(
         3. Execute MoE computation (single or multiple chunks)
         4. Handle output truncation and EPLB repeat
         """
+        # TODO: clarify whether output_dtype is needed.
+        if isinstance(x, Fp4QuantizedTensor):
+            assert output_dtype is not None
+        else:
+            output_dtype = x.dtype
         # ========== Step 1: Handle padding ==========
         if all_rank_num_tokens is None:
             all_rank_num_tokens = [x.shape[0]]
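
The new guard can be read in isolation: a packed FP4 input carries no usable element dtype, so the caller must name the output dtype explicitly, while a plain tensor falls back to its own. A minimal sketch of that rule, using a stand-in `Fp4QuantizedTensor` with hypothetical fields (the real class ships with the codebase):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Fp4QuantizedTensor:
    """Stand-in: packed FP4 payload plus scaling factors (hypothetical fields)."""
    fp4_tensor: torch.Tensor      # packed uint8 storage, not a usable output dtype
    scaling_factor: torch.Tensor


def resolve_output_dtype(x, output_dtype: Optional[torch.dtype]) -> torch.dtype:
    # Mirrors the guard above: FP4 inputs need an explicit dtype,
    # everything else defaults to the input's dtype.
    if isinstance(x, Fp4QuantizedTensor):
        assert output_dtype is not None, "output_dtype required for FP4 inputs"
        return output_dtype
    return x.dtype


assert resolve_output_dtype(torch.ones(4, dtype=torch.bfloat16), None) == torch.bfloat16
```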
@@ -662,7 +667,7 @@ def _forward_chunk_impl(
             token_final_scales=token_final_scales,
             x_sf=x_sf,
             **self._get_backend_kwargs(
-                router_logits, do_finalize, all_rank_num_tokens, output_dtype
+                router_logits, do_finalize, all_rank_num_tokens, output_dtype, x
             ),
         )

@@ -875,12 +880,68 @@ def _is_using_nvlink_two_sided(self) -> bool:
875880 """Check if using NVLinkTwoSided communication strategy"""
876881 return isinstance (self .comm , NVLinkTwoSided )
877882
883+ def _get_nvlink_onesided_moe_output (
884+ self ,
885+ all_rank_num_tokens : Optional [List [int ]],
886+ output_dtype : Optional [torch .dtype ],
887+ ) -> Optional [torch .Tensor ]:
888+ """
889+ Get workspace output buffer for NVLinkOneSided communication backend.
890+
891+ This method handles moe_output allocation for both CutlassFusedMoE and TRTLLMGenFusedMoE
892+ when using NVLinkOneSided communication strategy.
893+
894+ Args:
895+ all_rank_num_tokens: Token counts per rank
896+ output_dtype: Output data type
897+
898+ Returns:
899+ moe_output tensor if NVLinkOneSided is used and backend supports it, None otherwise
900+ """
901+ if not isinstance (self .comm , NVLinkOneSided ):
902+ return None
903+
904+ # Determine workspace dtype and whether backend supports workspace output
905+ workspace_dtype = output_dtype
906+ backend_supports_workspace = False
907+
908+ if isinstance (self .backend , TRTLLMGenFusedMoE ):
909+ # TRTLLMGen specific configuration
910+ self .comm .invalid_token_expert_id = - 1
911+ workspace_dtype = torch .bfloat16
912+ backend_supports_workspace = self .backend .has_w4a8_mxfp4_mxfp8
913+ elif isinstance (self .backend , CutlassFusedMoE ):
914+ # Cutlass always supports workspace output with NVLinkOneSided
915+ backend_supports_workspace = True
916+
917+ if not backend_supports_workspace :
918+ # Ensure payload_in_workspace is False if backend doesn't support it
919+ self .comm .payload_in_workspace = False
920+ return None
921+
922+ # Calculate runtime max tokens per rank
923+ assert all_rank_num_tokens is not None , (
924+ "all_rank_num_tokens must be provided for NVLinkOneSided backend"
925+ )
926+ runtime_max_tokens_per_rank = max (all_rank_num_tokens )
927+
928+ # Get workspace-backed output tensor
929+ moe_output = self .comm .get_combine_payload_tensor_in_workspace (
930+ runtime_max_tokens_per_rank , self .hidden_size , workspace_dtype
931+ )
932+
933+ # Dynamically enable payload_in_workspace for this forward pass
934+ self .comm .payload_in_workspace = True
935+
936+ return moe_output
937+
878938 def _get_backend_kwargs (
879939 self ,
880940 router_logits : Optional [torch .Tensor ] = None ,
881941 do_finalize : bool = True ,
882942 all_rank_num_tokens : Optional [List [int ]] = None ,
883943 output_dtype : Optional [torch .dtype ] = None ,
944+ x : Optional [torch .Tensor ] = None ,
884945 ) -> Dict :
885946 """
886947 Get backend-specific keyword arguments for run_moe
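
The helper's gating reduces to a small decision table: TRTLLMGen pins the workspace dtype to bfloat16 and is gated on `has_w4a8_mxfp4_mxfp8`, Cutlass is unconditional, and any other backend opts out. A minimal sketch of that table, with string stand-ins for the backend classes:

```python
from typing import Tuple

import torch


def workspace_output_plan(backend: str, has_w4a8_mxfp4_mxfp8: bool,
                          output_dtype: torch.dtype) -> Tuple[bool, torch.dtype]:
    """Return (supports_workspace, workspace_dtype) under an NVLinkOneSided comm."""
    if backend == "trtllmgen":
        # TRTLLMGen always combines in bf16 and only supports workspace
        # output under w4a8_mxfp4_mxfp8 quantization.
        return has_w4a8_mxfp4_mxfp8, torch.bfloat16
    if backend == "cutlass":
        return True, output_dtype
    return False, output_dtype


assert workspace_output_plan("cutlass", False, torch.float16) == (True, torch.float16)
assert workspace_output_plan("trtllmgen", False, torch.float16) == (False, torch.bfloat16)
assert workspace_output_plan("trtllmgen", True, torch.float16) == (True, torch.bfloat16)
```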
@@ -905,6 +966,8 @@ def _get_backend_kwargs(
             router_logits: Router logits tensor (for TRTLLMGen backend)
             do_finalize: Whether to finalize output (for TRTLLMGen backend)
             all_rank_num_tokens: Token counts per rank (for TRTLLMGen backend moe_output)
+            output_dtype: Output data type
+            x: Input tensor (for calculating tuner_num_tokens in Cutlass)

         Returns:
             Dict: Backend-specific keyword arguments
@@ -917,7 +980,33 @@ def _get_backend_kwargs(

         # Cutlass-specific parameters
         if self.backend.__class__ == CutlassFusedMoE:
-            pass
+            # Determine if scaling factors are swizzled based on communication flow:
+            # in post-quant communication (quantize -> dispatch), scaling factors are
+            # not swizzled; in pre-quant communication (dispatch -> quantize), they are.
+            supports_post_quant = self.comm is not None and self.comm.supports_post_quant_dispatch()
+            kwargs["is_sf_swizzled"] = not supports_post_quant
+            kwargs["output_dtype"] = output_dtype
+
+            # Prepare additional information for profiling in case padding is applied
+            # when using alltoall. Only the non-alltoall case is considered for
+            # profiling in the warmup phase. Therefore, to get the correct tactics
+            # during actual inference, the inputs to the tuner should be the same
+            # as when not using alltoall.
+            if self._is_using_alltoall():
+                if all_rank_num_tokens is not None:
+                    kwargs["tuner_num_tokens"] = sum(all_rank_num_tokens)
+                else:
+                    kwargs["tuner_num_tokens"] = (
+                        x.shape[0] * self.mapping.tp_size if x is not None else None
+                    )
+                kwargs["tuner_top_k"] = self.routing_method.top_k
+            else:
+                kwargs["tuner_num_tokens"] = None
+                kwargs["tuner_top_k"] = None
+
+            # Get moe_output for the NVLinkOneSided backend
+            kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
+                all_rank_num_tokens, output_dtype
+            )

         # CuteDSL-specific parameters
         elif self.backend.__class__ == CuteDslFusedMoE:
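
The tuner bookkeeping above exists because warmup profiles only the non-alltoall path, so under alltoall the tuner must be fed the same global token count that an allgather-style dispatch would present. A worked sketch with illustrative numbers:

```python
from typing import List, Optional


def tuner_num_tokens(using_alltoall: bool,
                     all_rank_num_tokens: Optional[List[int]],
                     local_num_tokens: int, tp_size: int) -> Optional[int]:
    if not using_alltoall:
        return None  # no override: the profiled path already matches inference
    if all_rank_num_tokens is not None:
        return sum(all_rank_num_tokens)  # exact global token count
    return local_num_tokens * tp_size    # fallback: assume uniform ranks


# Four ranks holding [7, 5, 8, 6] tokens: the tuner sees 26 tokens, exactly
# what the profiled non-alltoall dispatch would see for this batch.
assert tuner_num_tokens(True, [7, 5, 8, 6], 7, 4) == 26
assert tuner_num_tokens(True, None, 7, 4) == 28
assert tuner_num_tokens(False, [7, 5, 8, 6], 7, 4) is None
```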
@@ -940,37 +1029,10 @@ def _get_backend_kwargs(
             kwargs["router_logits"] = router_logits_arg
             kwargs["do_finalize"] = do_finalize

-            # moe_output: workspace output buffer for NVLINK one-sided backend.
-            # TRTLLMGenFusedMoE only supports workspace output for w4a8_mxfp4_mxfp8 quantization.
-            moe_output = None
-            if isinstance(self.comm, NVLinkOneSided):
-                # Determine dtype for workspace tensor:
-                # TRTLLMGenFusedMoE always uses bfloat16, other backends use output_dtype.
-                workspace_dtype = output_dtype
-                if isinstance(self.backend, TRTLLMGenFusedMoE):
-                    self.comm.invalid_token_expert_id = -1
-                    workspace_dtype = torch.bfloat16
-
-                # Check if backend supports workspace output for current quantization
-                backend_supports_workspace = (
-                    isinstance(self.backend, TRTLLMGenFusedMoE)
-                    and self.backend.has_w4a8_mxfp4_mxfp8
-                )
-                if backend_supports_workspace:
-                    assert all_rank_num_tokens is not None, (
-                        "all_rank_num_tokens must be provided for NVLinkOneSided backend with workspace output"
-                    )
-                    runtime_max_tokens_per_rank = max(all_rank_num_tokens)
-
-                    moe_output = self.comm.get_combine_payload_tensor_in_workspace(
-                        runtime_max_tokens_per_rank, self.hidden_size, workspace_dtype
-                    )
-                    # Dynamically enable payload_in_workspace for this forward pass
-                    self.comm.payload_in_workspace = True
-                else:
-                    # Ensure payload_in_workspace is False for non-workspace output
-                    self.comm.payload_in_workspace = False
-            kwargs["moe_output"] = moe_output
+            # Get moe_output for the NVLinkOneSided backend
+            kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
+                all_rank_num_tokens, output_dtype
+            )

         return kwargs

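One detail worth noting: `_get_backend_kwargs` dispatches on `self.backend.__class__ == ...` rather than `isinstance(...)`. If one backend subclasses another, exact-class matching keeps the subclass out of its parent's branch; a minimal sketch with a hypothetical hierarchy (for illustration only, not the real class relationship):

```python
class CutlassFusedMoE:
    pass


class CuteDslFusedMoE(CutlassFusedMoE):  # hypothetical subclass relationship
    pass


def branch_for(backend) -> str:
    # Exact-class checks, as in _get_backend_kwargs.
    if backend.__class__ == CutlassFusedMoE:
        return "cutlass"
    elif backend.__class__ == CuteDslFusedMoE:
        return "cutedsl"
    return "other"


assert branch_for(CuteDslFusedMoE()) == "cutedsl"  # isinstance() would match "cutlass" first
```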