@@ -430,14 +430,15 @@ def is_post_quant_all2all_supported(self):
         return False

     def forward_chunk(
-        self,
-        x: Union[torch.Tensor, Fp4QuantizedTensor],
-        router_logits: torch.Tensor,
-        use_all_to_all: bool,
-        output_dtype: Optional[torch.dtype] = None,
-        all_rank_num_tokens: Optional[List[int]] = None,
-        use_dp_padding: Optional[bool] = None,
-        repeating_info: Tuple = (True, True),
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        use_all_to_all: bool,
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        use_dp_padding: Optional[bool] = None,
+        repeating_info: Tuple = (True, True),
+        alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         all_rank_max_num_tokens = max(all_rank_num_tokens)
         if isinstance(x, Fp4QuantizedTensor):
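The hunk above threads a new `alltoall_result_do_sum` flag (default `True`, so existing callers keep the old behavior) through `forward_chunk`. The flag decides whether the all-to-all combine reduces each token's `top_k` expert partials itself or hands them back unreduced. A minimal reference sketch of that semantics, assuming a `[num_tokens * top_k, hidden]` layout (an illustration only, not the actual combine kernel):

```python
import torch


def combine_reference(expert_outputs: torch.Tensor, top_k: int,
                      do_sum: bool = True) -> torch.Tensor:
    """Toy stand-in for the all-to-all combine's reduction step."""
    # View as [num_tokens, top_k, hidden]: one partial result per routed expert.
    partials = expert_outputs.view(-1, top_k, expert_outputs.shape[-1])
    # do_sum=True reduces here; do_sum=False returns the unreduced partials
    # so the caller can fuse the sum into a later kernel.
    return partials.sum(dim=1) if do_sum else partials
```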
@@ -452,7 +453,7 @@ def forward_chunk(
             self.layer_load_balancer.start_wait_gpu_stage()

         if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
-            pass
+            alltoall_result_do_sum = True

         weight_dtype = self.w3_w1_weight.dtype

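Where the old code had a `pass` placeholder, the flag is now forced back to `True` whenever the MNNVL all-to-all path is not taken: deferring the reduction is only meaningful for the MNNVL combine, so every other communication path still produces a fully reduced result.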
@@ -719,7 +720,8 @@ def forward_chunk(
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             final_hidden_states = self.alltoall_combine(
-                final_hidden_states, alltoall_info, token_count)
+                final_hidden_states, alltoall_info, token_count,
+                alltoall_result_do_sum)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             final_hidden_states = self.unpad_tensors(
                 padded, final_hidden_states)
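On the MNNVL branch the flag is passed straight through to `alltoall_combine`; the DeepEP branch below is untouched and keeps its existing unpadding flow.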
@@ -764,6 +766,7 @@ def forward_impl(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        alltoall_result_do_sum: bool = True,
         **kwargs,
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
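`forward_impl` now exposes the same knob, again defaulting to `True` so existing callers are unaffected. A hypothetical invocation, assuming the leading positional parameters mirror `forward_chunk` (`x`, `router_logits`, `use_all_to_all`) and using placeholder names:

```python
# Hypothetical call; `moe`, `x`, `router_logits`, and
# `all_rank_num_tokens` are placeholders, not names from this file.
output = moe.forward_impl(
    x,
    router_logits,
    use_all_to_all=True,
    all_rank_num_tokens=all_rank_num_tokens,
    use_dp_padding=False,
    alltoall_result_do_sum=False,  # hand back unreduced top_k partials
)
```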
@@ -791,7 +794,8 @@ def forward_impl(
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding,
-                repeating_info=(is_first_call, is_last_call))
+                repeating_info=(is_first_call, is_last_call),
+                alltoall_result_do_sum=alltoall_result_do_sum)
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
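The single-chunk path forwards the flag unchanged into `forward_chunk` before the reducescatter/allreduce step.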
@@ -849,7 +853,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     if idx_chunk > 0:
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -865,7 +870,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -879,7 +885,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     router_logits,
                     use_all_to_all,
                     all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
-                    repeating_info=(is_first_call, is_last_call))
+                    repeating_info=(is_first_call, is_last_call),
+                    alltoall_result_do_sum=alltoall_result_do_sum)

                 outputs_list.append(outputs)
             if not use_all_to_all:
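Each of the three chunked call sites above forwards the flag verbatim into `forward_chunk`, so the chunked path makes the same reduce-or-defer decision as the single-chunk path.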
@@ -935,7 +942,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
         return x, x_sf, token_selected_slots, token_final_scales

     def alltoall_combine(self, final_hidden_states: torch.Tensor,
-                         alltoall_info: MoEAlltoallInfo, token_count: int):
+                         alltoall_info: MoEAlltoallInfo, token_count: int,
+                         alltoall_result_do_sum: bool):
         top_k = self.routing_method.experts_per_token
         if isinstance(final_hidden_states, list):
             final_hidden_states = final_hidden_states[0]
@@ -948,7 +956,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
                 top_k=top_k,
                 token_count=token_count,
                 use_low_precision_combine=self.use_low_precision_combine,
-                do_reduce=False)
+                do_reduce=alltoall_result_do_sum)

         return final_hidden_states

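The last hunk carries the actual behavioral change: `alltoall_combine`'s `do_reduce` argument was hard-coded to `False` and now follows the caller's `alltoall_result_do_sum`. Note that, as read from this diff, the flag's default of `True` flips this call site unless a caller explicitly passes `False`. A quick self-check of the two modes, reusing the `combine_reference` sketch from above:

```python
import torch

top_k, tokens, hidden = 2, 4, 8
partials = torch.randn(tokens * top_k, hidden)

reduced = combine_reference(partials, top_k, do_sum=True)
unreduced = combine_reference(partials, top_k, do_sum=False)

# Summing the deferred partials must reproduce the eagerly reduced result.
assert torch.allclose(reduced, unreduced.sum(dim=1))
```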