diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
index 3123783afa4..ba784285cec 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -1,8 +1,11 @@
+import os
+from functools import cached_property
 from typing import Dict, List, Optional, Union
 
 import torch
 from torch import nn
 
+from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
 from tensorrt_llm._utils import get_sm_version
 
 from ...custom_ops.trtllm_gen_custom_ops import \
@@ -106,10 +109,28 @@ def __init__(
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition
 
+        self.alltoall_workspace = None
+        self.alltoall_prepare_workspace = None
+        if self.enable_alltoall:
+            MnnvlMemory.initialize()
+            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                model_config.mapping)
+            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                model_config.mapping)
+
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
             self.create_weights()
 
+    @cached_property
+    def enable_alltoall(self):
+        mapping = self.mapping
+        routing_experts = self.routing_method.experts_per_token
+        return (mapping.moe_ep_size > routing_experts
+                and mapping.enable_attention_dp and mapping.tp_size > 1
+                and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
+                and MnnvlMemory.supports_mnnvl())
+
     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
             or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
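The new `enable_alltoall` property gates the MNNVL alltoall path on expert-parallel width, attention DP, an env-var kill switch, and hardware support; it is a `cached_property` because `__init__` reads it to decide whether to allocate the MNNVL workspaces, so the answer must not change afterwards. A minimal, self-contained sketch of the same predicate (this `Mapping` dataclass is a hypothetical stand-in for `model_config.mapping`):

```python
# Minimal sketch of the gating above; the Mapping dataclass is a
# hypothetical stand-in for tensorrt_llm's model_config.mapping.
import os
from dataclasses import dataclass


@dataclass
class Mapping:
    moe_ep_size: int
    tp_size: int
    enable_attention_dp: bool


def should_enable_alltoall(mapping: Mapping, experts_per_token: int,
                           supports_mnnvl: bool) -> bool:
    # Alltoall only pays off when EP ranks outnumber the experts each
    # token routes to; otherwise every rank needs most tokens anyway.
    return (mapping.moe_ep_size > experts_per_token
            and mapping.enable_attention_dp and mapping.tp_size > 1
            and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
            and supports_mnnvl)


# e.g. 16-way EP with top-8 routing qualifies; 8-way EP does not.
assert should_enable_alltoall(Mapping(16, 16, True), 8, True)
assert not should_enable_alltoall(Mapping(8, 8, True), 8, True)
```

The `moe_ep_size > experts_per_token` term is the interesting one: once EP ranks outnumber each token's expert fan-out, most ranks do not need most tokens, so a sparse alltoall moves strictly less data than an allgather.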
@@ -175,6 +196,48 @@ def load_weights(self, weights: List[Dict]):
     def post_load_weights(self):
         self.quant_method.post_load_weights(self)
 
+    def _quantize_for_post_quant_comm(self, x):
+        """Quantize inputs before the post-quant communication (alltoall/allgather).
+
+        Returns: (x, x_sf, x_row, x_col)
+        """
+        x_row = x.shape[0]
+        x_col = x.shape[1]
+        x_sf = None
+        if self.has_w4a8_mxfp4_fp8:
+            x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                x, self.fc31_input_dequant[0])
+            x_row, x_col = x.shape[0], x.shape[1]
+        elif self.has_nvfp4:
+            if isinstance(x, Fp4QuantizedTensor):
+                assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
+                x_row = x.shape[0]
+                x_col = x.shape[1] * 2
+                x, x_sf = x.fp4_tensor, x.scaling_factor
+            else:
+                x_row = x.shape[0]
+                x_col = x.shape[1]
+                x, x_sf = torch.ops.trtllm.fp4_quantize(
+                    x, self.fc31_input_scale, self.scaling_vector_size, False,
+                    False)
+        elif self.has_w4a8_mxfp4_mxfp8:
+            x, x_sf = torch.ops.trtllm.mxfp8_quantize(
+                x, False, alignment=self.quant_method.weight_alignment)
+            x_row, x_col = x.shape[0], x.shape[1]
+        elif self.has_deepseek_fp8_block_scales:
+            # No change required before communication
+            pass
+        elif self.has_w4a16_mxfp4:
+            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+            x = torch.nn.functional.pad(x, (0, pad_size))
+        elif self.has_w4a8_nvfp4_fp8:
+            x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                x, 1.0 / self.fc31_input_scale)
+        else:
+            raise ValueError(
+                f"unsupported quantization mode for post-quant communication: {self.quant_config.quant_mode}"
+            )
+        return x, x_sf, x_row, x_col
+
     def forward_impl(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
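`_quantize_for_post_quant_comm` factors out the quantization block that previously lived inline in `forward_impl`, and it is what makes quantize-before-communicate pay off: the collective then moves 4- or 8-bit payloads plus small scale factors instead of bf16 activations (for NVFP4, `x_col = x.shape[1] * 2` because each uint8 packs two fp4 values). Back-of-envelope payload math, with illustrative sizes that are assumptions rather than values read from the module:

```python
# Rough payload math for quantize-before-communication; the token and
# hidden sizes here are illustrative assumptions.
def payload_bytes(tokens: int, hidden: int, bits_per_elem: int,
                  sf_block: int = 0) -> int:
    data = tokens * hidden * bits_per_elem // 8
    # One FP8 scale factor per sf_block elements, when a block size is given.
    scales = tokens * hidden // sf_block if sf_block else 0
    return data + scales


tokens, hidden = 1024, 7168
bf16 = payload_bytes(tokens, hidden, 16)               # no scale factors
nvfp4 = payload_bytes(tokens, hidden, 4, sf_block=16)  # 16-elem scale blocks
print(f"bf16: {bf16 >> 20} MiB, nvfp4: {nvfp4 >> 20} MiB")
# Roughly 3.6x less traffic per rank, which compounds across EP ranks.
```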
@@ -202,55 +265,80 @@ def forward_impl(
         topk_group = None
         routed_scaling_factor = None
 
-        run_post_quant_allgather = self.use_dp and self.parallel_size > 1
+        run_post_quant_allgather = (self.use_dp and self.parallel_size > 1
+                                    and not self.enable_alltoall)
+        post_quant_comm = run_post_quant_allgather or self.enable_alltoall
 
         x_sf = None
         token_selected_experts = None
         token_final_scales = None
         x_row = x.shape[0]
         x_col = x.shape[1]
-        if run_post_quant_allgather:
-            # apply routing
+        token_count = x.shape[0]
+        alltoall_info = None
+
+        if post_quant_comm:
             token_selected_experts, token_final_scales = self.routing_method.apply(
                 router_logits)
-            token_final_scales = token_final_scales.to(torch.bfloat16)
-            assert token_final_scales.dtype == torch.bfloat16
-            assert token_selected_experts.dtype == torch.int32
-            # quantize inputs
-            if self.has_w4a8_mxfp4_fp8:
-                x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
-                    x, self.fc31_input_dequant[0])
-                # Update x_row and x_col to the padded shape
-                x_row, x_col = x.shape[0], x.shape[1]
-            elif self.has_nvfp4:
-                if isinstance(x, Fp4QuantizedTensor):
-                    assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
-                    x_row = x.shape[0]
-                    # note: we use uint8 to store 2 fp4 values
-                    x_col = x.shape[1] * 2
-                    x, x_sf = x.fp4_tensor, x.scaling_factor
-                else:
-                    x_row = x.shape[0]
-                    x_col = x.shape[1]
-                    x, x_sf = torch.ops.trtllm.fp4_quantize(
-                        x, self.fc31_input_scale, self.scaling_vector_size,
-                        False, False)
-            elif self.has_w4a8_mxfp4_mxfp8:
-                x, x_sf = torch.ops.trtllm.mxfp8_quantize(
-                    x, False, alignment=self.quant_method.weight_alignment)
-                # Update x_row and x_col to the padded shape
-                x_row, x_col = x.shape[0], x.shape[1]
-            elif self.has_deepseek_fp8_block_scales:
-                pass
-            elif self.has_w4a16_mxfp4:
-                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-                x = torch.nn.functional.pad(x, (0, pad_size))
+            token_selected_experts = token_selected_experts.to(torch.int32)
+            if token_final_scales is not None:
+                token_final_scales = token_final_scales.to(torch.bfloat16)
+
+            x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)
+
+        if self.enable_alltoall:
+            assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall"
+
+            max_num_token = max(
+                all_rank_num_tokens) if all_rank_num_tokens else token_count
+
+            if token_final_scales is None:
+                token_final_scales = torch.ones_like(token_selected_experts,
+                                                     dtype=torch.float32)
             else:
-                raise ValueError(
-                    f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
-                )
+                token_final_scales = token_final_scales.to(torch.float32)
+
+            assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
+            alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
+                token_selected_experts,
+                None,
+                self.alltoall_prepare_workspace,
+                max_num_token,
+                self.ep_rank,
+                self.ep_size,
+                self.num_experts,
+                self.num_slots,
+                top_k,
+            )
 
-            #allgather for attention DP
+            if x_sf is not None:
+                x_sf = x_sf.view(x_row, ceil_div(x_col,
+                                                 self.scaling_vector_size))
+
+            x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
+                [x, x_sf, token_selected_experts, token_final_scales],
+                alltoall_info,
+                self.alltoall_workspace,
+                self.ep_rank,
+                self.ep_size,
+            )
+
+            torch.ops.trtllm.memset_expert_ids(
+                token_selected_experts,
+                alltoall_info.recv_rank_count_cumsum,
+                max_num_token,
+                top_k,
+                self.num_slots,
+                self.ep_size,
+            )
+
+            if x_sf is not None:
+                x_sf = x_sf.flatten()
+
+            if token_final_scales is not None:
+                token_final_scales = token_final_scales.to(torch.bfloat16)
+
+        elif run_post_quant_allgather:
             if x_sf is not None:
                 x_sf = x_sf.view(x_row,
                                  ceil_div(x_col, self.scaling_vector_size))
@@ -265,6 +353,9 @@ def forward_impl(
         if x_sf is not None:
             x_sf = x_sf.flatten()
 
+        router_logits_arg = router_logits if not post_quant_comm else None
+        routing_bias_arg = routing_bias if not post_quant_comm else None
+
         # TODO: since routing kernel is integrated into moe_runner for fp8,
         # here we just route the I/Os for moe_runner
         if self.has_deepseek_fp8_block_scales:
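`router_logits_arg`/`routing_bias_arg` encode the contract with the kernels below: on any `post_quant_comm` path the routing decision was already taken above via `routing_method.apply` (and, on the alltoall path, tokens were already placed on their expert-owning ranks), so the runners receive `None` and must not re-route. A hypothetical sketch of the routing output the code above normalizes to int32 ids and bf16 scales; real routing methods differ (sigmoid scores, grouped top-k, renormalization), so treat this softmax top-k as illustrative only:

```python
# Hypothetical stand-in for routing_method.apply: top-k expert ids and
# per-token scales, cast to the dtypes the downstream kernels expect.
import torch


def apply_topk_routing(router_logits: torch.Tensor, top_k: int):
    scores = torch.softmax(router_logits.float(), dim=-1)
    token_final_scales, token_selected_experts = scores.topk(top_k, dim=-1)
    return (token_selected_experts.to(torch.int32),
            token_final_scales.to(torch.bfloat16))


logits = torch.randn(4, 64)  # 4 tokens, 64 expert slots
experts, scales = apply_topk_routing(logits, top_k=8)
assert experts.dtype == torch.int32 and experts.shape == (4, 8)
assert scales.dtype == torch.bfloat16
```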
@@ -272,8 +363,8 @@ def forward_impl(
             x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
 
             final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 x_val,
                 x_scale,
                 self.w3_w1_weight,
@@ -297,7 +388,7 @@ def forward_impl(
 
             scale_factor_use_ue8m0 = False
             is_scale_factor_swizzled = False  # use linear layout here
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 hidden_states_fp4, hidden_states_scale_linear_fp4 = (
                     torch.ops.trtllm.fp4_quantize(
                         x,
@@ -310,8 +401,8 @@ def forward_impl(
                 hidden_states_fp4, hidden_states_scale_linear_fp4 = x, x_sf
 
             outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 hidden_states_fp4,
                 hidden_states_scale_linear_fp4.view(torch.float8_e4m3fn),
                 self.w3_w1_weight,
@@ -343,7 +434,7 @@ def forward_impl(
             final_hidden_states = outputs[0]
         elif self.has_w4a16_mxfp4:
             assert x.dtype == torch.bfloat16
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
                 x = torch.nn.functional.pad(x, (0, pad_size))
             else:
@@ -352,8 +443,8 @@ def forward_impl(
             intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
                 -2] // 2
             final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,
@@ -383,15 +474,15 @@ def forward_impl(
                                                       hidden_size].contiguous()
 
         elif self.has_w4a8_nvfp4_fp8:
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, 1.0 / self.fc31_input_scale)
             else:
                 hidden_states_fp8 = x
 
             outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits_arg,
+                routing_bias_arg,
                 hidden_states_fp8,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
@@ -423,7 +514,7 @@ def forward_impl(
             final_hidden_states = outputs[0]
         elif self.has_w4a8_mxfp4_fp8:
             pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 x = torch.nn.functional.pad(x, (0, pad_size))
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_gate_dequant[0])
@@ -433,8 +524,8 @@ def forward_impl(
                 -2] // 2
 
             final_hidden_states = torch.ops.trtllm.e4m3_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,
@@ -466,7 +557,7 @@ def forward_impl(
             final_hidden_states = final_hidden_states[:, :self.
                                                       hidden_size].contiguous()
         elif self.has_w4a8_mxfp4_mxfp8:
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 # TRTLLM-Gen uses linear SF layout for the mxfp8 input.
                 mxfp8_x, sf = torch.ops.trtllm.mxfp8_quantize(
                     x, False, alignment=self.quant_method.weight_alignment)
@@ -477,8 +568,8 @@ def forward_impl(
                 -2] // 2
 
             final_hidden_states = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 mxfp8_x,
                 sf,
                 self.w3_w1_weight,
@@ -511,6 +602,18 @@ def forward_impl(
                 "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
             )
 
+        # Combine results if using alltoall
+        if self.enable_alltoall and alltoall_info is not None:
+            final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
+                final_hidden_states,
+                alltoall_info,
+                self.alltoall_workspace,
+                ep_rank=self.ep_rank,
+                ep_size=self.ep_size,
+                top_k=top_k,
+                token_count=token_count,
+            )
+
         final_hidden_states = self.reducescatter_or_allreduce(
             final_hidden_states,
             all_rank_num_tokens=all_rank_num_tokens,
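`mnnvl_moe_alltoallv_combine` reverses the dispatch: each token's `top_k` partial expert outputs travel back to the originating rank and are reduced, restoring the local `token_count` rows before the usual reducescatter/allreduce. A single-process toy of that reduction step, assuming the partials are already weighted by `token_final_scales`:

```python
# Toy, single-process illustration of the combine reduction: sum each
# token's top_k expert outputs back into one row per original token.
import torch

token_count, top_k, hidden = 4, 2, 8
# One partial result per (token, routed expert) pair after the MoE runner.
partials = torch.randn(token_count * top_k, hidden)
# Originating token index for every partial row.
src_token = torch.arange(token_count).repeat_interleave(top_k)

combined = torch.zeros(token_count, hidden)
combined.index_add_(0, src_token, partials)
assert combined.shape == (token_count, hidden)
```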