@@ -591,22 +591,20 @@ def determine_expert_map(
591591     if ep_size == 1:
592592         return (global_num_experts, None)
593593
594-         local_num_experts = global_num_experts // ep_size
594+     # Distribute experts as evenly as possible to each rank.
595+     base_experts = global_num_experts // ep_size
596+     remainder = global_num_experts % ep_size
597+     if ep_rank < remainder:
598+         local_num_experts = base_experts + 1
599+     else:
600+         local_num_experts = base_experts
595601
596602     # Create a tensor of size num_experts filled with -1
597603     expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
598604     # Create an expert map for the local experts
599-     if ep_rank < (ep_size - 1):
600-         # Each non-last rank gets local_num_experts experts.
601-         expert_map[ep_rank * local_num_experts:
602-                    (ep_rank + 1) * local_num_experts] = \
603-             torch.arange(0, local_num_experts, dtype=torch.int32)
604-     else:
605-         # All remaining experts are assigned to the last rank.
606-         local_num_experts = (global_num_experts - ep_rank * local_num_experts)
607-
608-         expert_map[-local_num_experts:] = \
609-             torch.arange(0, local_num_experts, dtype=torch.int32)
605+     start_idx = ep_rank * base_experts + min(ep_rank, remainder)
606+     expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
607+         0, local_num_experts, dtype=torch.int32)
610608     return (local_num_experts, expert_map)
611609
612610
0 commit comments