@@ -192,8 +192,12 @@ def __init__(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(hidden_size, dtype)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-            self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
-                                              self.moe_max_num_tokens)
+            self.deep_ep_max_num_tokens = int(
+                os.environ.get(
+                    "TRTLLM_DEEP_EP_TOKEN_LIMIT",
+                    str(
+                        min(model_config.max_num_tokens,
+                            self.moe_max_num_tokens))))
             self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
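For reference, a minimal standalone sketch of the override pattern this hunk introduces: the DeepEP low-latency token limit defaults to the stricter of the two engine limits, and the TRTLLM_DEEP_EP_TOKEN_LIMIT environment variable can force a specific value. The helper name and numbers below are illustrative, not part of the change.

import os

def resolve_deep_ep_token_limit(max_num_tokens: int, moe_max_num_tokens: int) -> int:
    # Default to the tighter of the two limits; the env var, if set, wins.
    default = min(max_num_tokens, moe_max_num_tokens)
    return int(os.environ.get("TRTLLM_DEEP_EP_TOKEN_LIMIT", str(default)))

print(resolve_deep_ep_token_limit(8192, 4096))  # 4096 when the variable is unset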
@@ -274,6 +278,25 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
+    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
+        num_rows = sum(all_rank_num_tokens)
+        return (num_rows + self.moe_max_num_tokens -
+                1) // self.moe_max_num_tokens
+
+    def can_use_alltoall(self, input, all_rank_num_tokens):
+        # Disable alltoall when chunking is used
+        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            return False
+
+        num_tokens = input.shape[0]
+
+        # For DeepEPLowLatency, check if tokens exceed the threshold
+        if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
+                and num_tokens > self.deep_ep_max_num_tokens):
+            return False
+
+        return self.enable_alltoall
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
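In plain terms, the two helpers added above gate the alltoall path: the global token count is ceiling-divided into chunks of at most moe_max_num_tokens, and alltoall is skipped whenever more than one chunk is needed or, for DeepEPLowLatency, whenever the local token count exceeds deep_ep_max_num_tokens. A small worked example with assumed limits (not values taken from this change):

def num_chunks(total_tokens: int, moe_max_num_tokens: int) -> int:
    # Ceiling division, mirroring calculate_num_chunks above.
    return (total_tokens + moe_max_num_tokens - 1) // moe_max_num_tokens

moe_max, deep_ep_max = 4096, 2048   # assumed limits for illustration
print(num_chunks(4000, moe_max))    # 1 -> alltoall stays eligible
print(num_chunks(9000, moe_max))    # 3 -> chunking forces the allgather fallback
print(3000 > deep_ep_max)           # True -> DeepEPLowLatency would also fall back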
@@ -316,11 +339,12 @@ def dummy_allreduce(self):
     def reducescatter_or_allreduce(
         self,
         inputs,
+        use_all_to_all: bool,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if not self.enable_alltoall:
+        if not use_all_to_all:
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             outputs = reducescatter(
@@ -334,6 +358,7 @@ def forward_chunk(
             self,
             x: Union[torch.Tensor, Fp4QuantizedTensor],
             router_logits: torch.Tensor,
+            use_all_to_all: bool,
             output_dtype: Optional[torch.dtype] = None,
             all_rank_num_tokens: Optional[List[int]] = None,
             all_rank_max_num_tokens: Optional[int] = None,
@@ -382,7 +407,7 @@ def forward_chunk(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()
 
-        use_allgather = not self.enable_alltoall
+        use_allgather = not use_all_to_all
 
         loadbalancer_local_statistic_info = None
         gathered_loadbalancer_local_statistic_info = None
@@ -391,7 +416,7 @@ def forward_chunk(
             token_selected_slots = token_selected_experts
         else:
             if not self.layer_load_balancer.is_static_routing(
-            ) and self.enable_alltoall:
+            ) and use_all_to_all:
                 self.layer_load_balancer.local_statistic(
                     token_selected_experts,
                     is_first_stage=is_first_call,
@@ -400,7 +425,7 @@ def forward_chunk(
                 token_selected_experts, self.use_dp)
             if not self.layer_load_balancer.is_static_routing():
                 # split into two part to get possible overlap with load balancer routing
-                if self.enable_alltoall:
+                if use_all_to_all:
                     if is_last_call:
                         loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                         )
@@ -412,7 +437,9 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
 
-        if self.enable_alltoall:
+        # If alltoall is disabled, we also need to disable use_postquant_alltoall
+        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -423,15 +450,16 @@ def forward_chunk(
                     x,
                     token_selected_slots,
                     token_final_scales,
+                    use_postquant_alltoall,
                     loadbalancer_local_statistic_info)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                         self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
                     padded, x, _, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
                         x, None, recv_topk_idx, token_final_scales)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots.to(torch.int64)
                     deep_ep_topk_weights = token_final_scales
                     x, recv_expert_count, deep_ep_handle = \
@@ -471,7 +499,7 @@ def forward_chunk(
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_dequant)
             elif self.has_nvfp4:
-                if use_allgather or self.use_postquant_alltoall:
+                if use_allgather or use_postquant_alltoall:
                     if isinstance(x, Fp4QuantizedTensor):
                         if use_allgather:
                             assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
@@ -527,7 +555,7 @@ def forward_chunk(
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
-            if self.enable_alltoall:
+            if use_all_to_all:
                 if is_last_call:
                     gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
                         (self.mapping.moe_ep_size, self.num_experts))
@@ -547,7 +575,7 @@ def forward_chunk(
             cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
-        if self.use_postquant_alltoall:
+        if use_postquant_alltoall:
             if x_sf is not None and self.has_nvfp4:
                 assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -640,7 +668,7 @@ def forward_chunk(
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
             # TODO: remove the adapter by changing APIs
             if self.alltoall_method_type == AlltoallMethodType.DeepEP:
@@ -666,7 +694,7 @@ def forward_chunk(
             ep_rank=ep_rank,
             cluster_size=cluster_size,
             cluster_rank=cluster_rank,
-            enable_alltoall=self.enable_alltoall,
+            enable_alltoall=use_all_to_all,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4a8_group_scaling=use_w4a8_group_scaling,
             min_latency_mode=False,
@@ -681,7 +709,7 @@ def forward_chunk(
             # Otherwise, the output should be unpacked as a single tensor.
             final_hidden_states = final_hidden_states[0]
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -737,11 +765,10 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
@@ -754,13 +781,15 @@ def forward(
             outputs = self.forward_chunk(
                 x,
                 router_logits,
+                use_all_to_all,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
                 use_dp_padding=use_dp_padding,
                 repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
+                use_all_to_all,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding)
         else:
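Taken together, forward() now makes the alltoall decision once per call and threads the same flag through every chunk and reduction stage, instead of each stage reading the static enable_alltoall property. A compressed, runnable outline of that control flow with stubbed stages (assumed names, not this class):

from typing import List

class MoEOutline:
    def can_use_alltoall(self, num_tokens: int, num_chunks: int) -> bool:
        # Stand-in for the real checks shown earlier in the diff.
        return num_chunks == 1 and num_tokens <= 2048

    def forward_chunk(self, chunk: str, use_all_to_all: bool) -> str:
        return f"{chunk}:{'alltoall' if use_all_to_all else 'allgather'}"

    def reducescatter_or_allreduce(self, out: str, use_all_to_all: bool) -> str:
        return out if use_all_to_all else f"reducescatter({out})"

    def forward(self, chunks: List[str], num_tokens: int) -> List[str]:
        # Decide once, then pass the flag to every stage so dispatch and
        # combine always agree on the same communication plan.
        use_all_to_all = self.can_use_alltoall(num_tokens, len(chunks))
        outs = [self.forward_chunk(c, use_all_to_all) for c in chunks]
        return [self.reducescatter_or_allreduce(o, use_all_to_all) for o in outs]

moe = MoEOutline()
print(moe.forward(["c0"], num_tokens=1024))        # alltoall path
print(moe.forward(["c0", "c1"], num_tokens=6000))  # chunked -> fallback path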
@@ -782,7 +811,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                        num_chunks)
             chunk_size_list = all_rank_chunk_size_list[self.rank]
-            if self.enable_alltoall:
+            if use_all_to_all:
                 all_rank_num_tokens_list = [[
                     1 if val == 0 else val for val in val_list
                 ] for val_list in all_rank_num_tokens_list]
@@ -794,7 +823,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             x_list = x.split(chunk_size_list)
             router_logits_list = router_logits.split(chunk_size_list)
 
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -805,12 +834,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     zip(x_list, router_logits_list)):
                 is_first_call = idx_chunk == 0 and self.repeat_idx == 0
                 is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
-                if not self.enable_alltoall:
+                if not use_all_to_all:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
                             outputs = self.forward_chunk(
                                 x,
                                 router_logits,
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk],
                                 all_rank_max_num_tokens=
@@ -820,13 +850,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         if idx_chunk > 0:
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                     else:
                         outputs = self.forward_chunk(
                             x,
                             router_logits,
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[
                                 idx_chunk],
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[
@@ -836,29 +868,33 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         with torch.cuda.stream(self.aux_stream):
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                 else:
                     outputs = self.forward_chunk(
                         x,
                         router_logits,
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk],
                         repeating_info=(is_first_call, is_last_call))
 
                 outputs_list.append(outputs)
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 if num_chunks % 2 == 0:
                     outputs_list[-1] = self.reducescatter_or_allreduce(
                         outputs_list[-1],
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[-1],
                         use_dp_padding=use_dp_padding)
                 else:
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[-1],
                             use_dp_padding=use_dp_padding)
             with torch.cuda.stream(self.aux_stream):
@@ -873,7 +909,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
     def alltoall_prepare_maybe_dispatch(
             self, all_rank_max_num_tokens: int, x: torch.Tensor,
             token_selected_slots: torch.Tensor,
-            token_final_scales: torch.Tensor,
+            token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
             local_statistic_tensor: Optional[torch.Tensor]):
         top_k = self.routing_method.experts_per_token
 
@@ -919,7 +955,7 @@ def alltoall_prepare_maybe_dispatch(
             gathered_token_final_scales, all_rank_max_num_tokens,
             self.num_slots, top_k, self.ep_rank, self.ep_size)
 
-        if not self.use_postquant_alltoall:
+        if not use_postquant_alltoall:
             assert not isinstance(
                 x, Fp4QuantizedTensor
             ), "pre-quant alltoall doesn't support fp4 tensor"
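Finally, the prepare/dispatch helpers also receive the flag because the exchange can sit on either side of quantization: when use_postquant_alltoall is false, the raw activations are dispatched immediately and quantized after the exchange, and when it is true the early dispatch is skipped so the already-quantized tensors can be exchanged later in forward_chunk. A schematic of the two orderings with assumed helper names, not the real kernels:

def exchange(payload: str) -> str:
    return f"alltoall({payload})"

def prepare_and_maybe_dispatch(x: str, use_postquant_alltoall: bool) -> str:
    if not use_postquant_alltoall:
        # Pre-quant path: exchange raw activations now, quantize after receiving.
        return exchange(x)
    # Post-quant path: keep x local for now; it is quantized first and exchanged later.
    return x

print(prepare_and_maybe_dispatch("hidden", use_postquant_alltoall=False))  # alltoall(hidden)
print(prepare_and_maybe_dispatch("hidden", use_postquant_alltoall=True))   # hidden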