+import os
+from functools import cached_property
 from typing import Dict, List, Optional, Union
 
 import torch
 from torch import nn
 
+from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
 from tensorrt_llm._utils import get_sm_version
 
 from ...custom_ops.trtllm_gen_custom_ops import \
@@ -106,10 +109,28 @@ def __init__(
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition
 
+        self.alltoall_workspace = None
+        self.alltoall_prepare_workspace = None
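+        # Pre-allocate MNNVL (multi-node NVLink) workspaces once at init so the
+        # alltoallv dispatch/combine calls can reuse them across forward passes.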
+        if self.enable_alltoall:
+            MnnvlMemory.initialize()
+            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                model_config.mapping)
+            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                model_config.mapping)
+
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
             self.create_weights()
 
+    @cached_property
+    def enable_alltoall(self):
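+        # Use the MNNVL alltoallv path only when EP is wide enough (more EP ranks
+        # than experts routed per token), attention DP with TP > 1 is active, the
+        # TRTLLM_MOE_DISABLE_ALLTOALLV kill switch is unset, and MNNVL is supported.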
+        mapping = self.mapping
+        routing_experts = self.routing_method.experts_per_token
+        return (mapping.moe_ep_size > routing_experts
+                and mapping.enable_attention_dp and mapping.tp_size > 1
+                and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
+                and MnnvlMemory.supports_mnnvl())
+
     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
             or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
@@ -175,6 +196,48 @@ def load_weights(self, weights: List[Dict]):
     def post_load_weights(self):
         self.quant_method.post_load_weights(self)
 
+    def _quantize_for_post_quant_comm(self, x):
+        """Quantize inputs prior to post-communication (alltoall/allgather).
+        Returns: (x, x_sf, x_row, x_col)
+        """
+        x_row = x.shape[0]
+        x_col = x.shape[1]
+        x_sf = None
+        if self.has_w4a8_mxfp4_fp8:
+            x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                x, self.fc31_input_dequant[0])
+            x_row, x_col = x.shape[0], x.shape[1]
+        elif self.has_nvfp4:
+            if isinstance(x, Fp4QuantizedTensor):
+                assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
+                x_row = x.shape[0]
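+                # note: we use uint8 to store 2 fp4 values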
+                x_col = x.shape[1] * 2
+                x, x_sf = x.fp4_tensor, x.scaling_factor
+            else:
+                x_row = x.shape[0]
+                x_col = x.shape[1]
+                x, x_sf = torch.ops.trtllm.fp4_quantize(
+                    x, self.fc31_input_scale, self.scaling_vector_size, False,
+                    False)
+        elif self.has_w4a8_mxfp4_mxfp8:
+            x, x_sf = torch.ops.trtllm.mxfp8_quantize(
+                x, False, alignment=self.quant_method.weight_alignment)
+            x_row, x_col = x.shape[0], x.shape[1]
+        elif self.has_deepseek_fp8_block_scales:
+            # No change required before communication
+            pass
+        elif self.has_w4a16_mxfp4:
+            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+            x = torch.nn.functional.pad(x, (0, pad_size))
+        elif self.has_w4a8_nvfp4_fp8:
+            x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                x, 1.0 / self.fc31_input_scale)
+        else:
+            raise ValueError(
+                f"unsupported quantization mode for post communication: {self.quant_config.quant_mode}"
+            )
+        return x, x_sf, x_row, x_col
+
     def forward_impl(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -202,55 +265,80 @@ def forward_impl(
         topk_group = None
         routed_scaling_factor = None
 
-        run_post_quant_allgather = self.use_dp and self.parallel_size > 1
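+        # Quantize-then-communicate: tokens are quantized first, then exchanged
+        # either via allgather (attention DP) or via the MNNVL alltoallv path.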
+        run_post_quant_allgather = (self.use_dp and self.parallel_size > 1
+                                    and not self.enable_alltoall)
+        post_quant_comm = run_post_quant_allgather or self.enable_alltoall
 
         x_sf = None
         token_selected_experts = None
         token_final_scales = None
         x_row = x.shape[0]
         x_col = x.shape[1]
-        if run_post_quant_allgather:
-            # apply routing
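+        # Keep the pre-dispatch token count; alltoallv_combine uses it to restore
+        # the original per-rank token layout.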
+        token_count = x.shape[0]
+        alltoall_info = None
+
+        if post_quant_comm:
             token_selected_experts, token_final_scales = self.routing_method.apply(
                 router_logits)
-            token_final_scales = token_final_scales.to(torch.bfloat16)
-            assert token_final_scales.dtype == torch.bfloat16
-            assert token_selected_experts.dtype == torch.int32
-            # quantize inputs
-            if self.has_w4a8_mxfp4_fp8:
-                x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
-                    x, self.fc31_input_dequant[0])
-                # Update x_row and x_col to the padded shape
-                x_row, x_col = x.shape[0], x.shape[1]
-            elif self.has_nvfp4:
-                if isinstance(x, Fp4QuantizedTensor):
-                    assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
-                    x_row = x.shape[0]
-                    # note: we use uint8 to store 2 fp4 values
-                    x_col = x.shape[1] * 2
-                    x, x_sf = x.fp4_tensor, x.scaling_factor
-                else:
-                    x_row = x.shape[0]
-                    x_col = x.shape[1]
-                    x, x_sf = torch.ops.trtllm.fp4_quantize(
-                        x, self.fc31_input_scale, self.scaling_vector_size,
-                        False, False)
-            elif self.has_w4a8_mxfp4_mxfp8:
-                x, x_sf = torch.ops.trtllm.mxfp8_quantize(
-                    x, False, alignment=self.quant_method.weight_alignment)
-                # Update x_row and x_col to the padded shape
-                x_row, x_col = x.shape[0], x.shape[1]
-            elif self.has_deepseek_fp8_block_scales:
-                pass
-            elif self.has_w4a16_mxfp4:
-                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-                x = torch.nn.functional.pad(x, (0, pad_size))
+            token_selected_experts = token_selected_experts.to(torch.int32)
+            if token_final_scales is not None:
+                token_final_scales = token_final_scales.to(torch.bfloat16)
+
+            x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)
+
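+        # MNNVL alltoallv path: compute the send/recv layout without an allgather,
+        # then dispatch quantized activations, scales, and routing metadata to the
+        # owning expert-parallel ranks.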
+        if self.enable_alltoall:
+            assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall"
+
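+            # Size communication buffers for the busiest rank.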
+            max_num_token = max(
+                all_rank_num_tokens) if all_rank_num_tokens else token_count
+
+            if token_final_scales is None:
+                token_final_scales = torch.ones_like(token_selected_experts,
+                                                     dtype=torch.float32)
             else:
-                raise ValueError(
-                    f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
-                )
+                token_final_scales = token_final_scales.to(torch.float32)
+
+            assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
+            alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
+                token_selected_experts,
+                None,
+                self.alltoall_prepare_workspace,
+                max_num_token,
+                self.ep_rank,
+                self.ep_size,
+                self.num_experts,
+                self.num_slots,
+                top_k,
+            )
 
-            #allgather for attention DP
+            if x_sf is not None:
+                x_sf = x_sf.view(x_row, ceil_div(x_col,
+                                                 self.scaling_vector_size))
+
+            x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
+                [x, x_sf, token_selected_experts, token_final_scales],
+                alltoall_info,
+                self.alltoall_workspace,
+                self.ep_rank,
+                self.ep_size,
+            )
+
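+            # memset_expert_ids resets expert ids in the padded region beyond the
+            # per-rank received token counts (recv_rank_count_cumsum).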
+            torch.ops.trtllm.memset_expert_ids(
+                token_selected_experts,
+                alltoall_info.recv_rank_count_cumsum,
+                max_num_token,
+                top_k,
+                self.num_slots,
+                self.ep_size,
+            )
+
+            if x_sf is not None:
+                x_sf = x_sf.flatten()
+
+            if token_final_scales is not None:
+                token_final_scales = token_final_scales.to(torch.bfloat16)
+
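+        # Allgather path (attention DP without MNNVL alltoall): gather the
+        # quantized activations, scales, and routing metadata from all ranks.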
+        elif run_post_quant_allgather:
             if x_sf is not None:
                 x_sf = x_sf.view(x_row, ceil_div(x_col,
                                                  self.scaling_vector_size))
@@ -265,15 +353,18 @@ def forward_impl(
             if x_sf is not None:
                 x_sf = x_sf.flatten()
 
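+        # When routing already ran before communication, pass None so the MoE
+        # runners skip their fused routing step.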
+        router_logits_arg = router_logits if not post_quant_comm else None
+        routing_bias_arg = routing_bias if not post_quant_comm else None
+
         # TODO: since routing kernel is integrated into moe_runner for fp8,
         # here we just route the I/Os for moe_runner
         if self.has_deepseek_fp8_block_scales:
             assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False"
             x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
 
             final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 x_val,
                 x_scale,
                 self.w3_w1_weight,
@@ -297,7 +388,7 @@ def forward_impl(
             scale_factor_use_ue8m0 = False
             is_scale_factor_swizzled = False  # use linear layout here
 
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 hidden_states_fp4, hidden_states_scale_linear_fp4 = (
                     torch.ops.trtllm.fp4_quantize(
                         x,
@@ -310,8 +401,8 @@ def forward_impl(
                 hidden_states_fp4, hidden_states_scale_linear_fp4 = x, x_sf
 
             outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 hidden_states_fp4,
                 hidden_states_scale_linear_fp4.view(torch.float8_e4m3fn),
                 self.w3_w1_weight,
@@ -343,7 +434,7 @@ def forward_impl(
             final_hidden_states = outputs[0]
         elif self.has_w4a16_mxfp4:
             assert x.dtype == torch.bfloat16
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
                 x = torch.nn.functional.pad(x, (0, pad_size))
             else:
@@ -352,8 +443,8 @@ def forward_impl(
             intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
                 -2] // 2
             final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,
@@ -383,15 +474,15 @@ def forward_impl(
                                                        hidden_size].contiguous()
         elif self.has_w4a8_nvfp4_fp8:
 
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, 1.0 / self.fc31_input_scale)
             else:
                 hidden_states_fp8 = x
 
             outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits_arg,
+                routing_bias_arg,
                 hidden_states_fp8,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
@@ -423,7 +514,7 @@ def forward_impl(
             final_hidden_states = outputs[0]
         elif self.has_w4a8_mxfp4_fp8:
             pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 x = torch.nn.functional.pad(x, (0, pad_size))
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_gate_dequant[0])
@@ -433,8 +524,8 @@ def forward_impl(
                 -2] // 2
 
             final_hidden_states = torch.ops.trtllm.e4m3_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,
@@ -466,7 +557,7 @@ def forward_impl(
             final_hidden_states = final_hidden_states[:, :self.
                                                       hidden_size].contiguous()
         elif self.has_w4a8_mxfp4_mxfp8:
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 # TRTLLM-Gen uses linear SF layout for the mxfp8 input.
                 mxfp8_x, sf = torch.ops.trtllm.mxfp8_quantize(
                     x, False, alignment=self.quant_method.weight_alignment)
@@ -477,8 +568,8 @@ def forward_impl(
                 -2] // 2
 
             final_hidden_states = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 mxfp8_x,
                 sf,
                 self.w3_w1_weight,
@@ -511,6 +602,18 @@ def forward_impl(
511602 "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
512603 )
513604
605+ # Combine results if using alltoall
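+        # alltoallv_combine routes each token's expert outputs back to its source
+        # rank and reduces the top_k contributions into one row per local token.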
+        if self.enable_alltoall and alltoall_info is not None:
+            final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
+                final_hidden_states,
+                alltoall_info,
+                self.alltoall_workspace,
+                ep_rank=self.ep_rank,
+                ep_size=self.ep_size,
+                top_k=top_k,
+                token_count=token_count,
+            )
+
         final_hidden_states = self.reducescatter_or_allreduce(
             final_hidden_states,
             all_rank_num_tokens=all_rank_num_tokens,