@@ -34,31 +34,27 @@ def __init__(self, num_experts=256, num_selected=8):
3434 def generate_balanced_tensor (self , num_tokens ):
3535 # Evenly distribute num_tokens to num_selected experts out of num_experts.
3636 # Note that the num_selected experts activated by a token cannot be repeated.
37- # Performance is not that important, as it is only activated in special scenarios.
38- tensor = torch .zeros ((num_tokens , self .num_selected ), dtype = torch .int , device = "cuda" )
37+ tensor = torch .empty ((num_tokens , self .num_selected ), dtype = torch .int , device = "cuda" )
3938 expert_load = torch .zeros (self .num_experts , dtype = torch .int , device = "cuda" )
4039
4140 for i in range (num_tokens ):
42- available_experts = torch .arange (self .num_experts , device = "cuda" )
43- selected = []
44- for _ in range (self .num_selected ):
45- current_load = expert_load [available_experts ]
46- min_load_indices = torch .where (current_load == current_load .min ())[0 ]
41+ selected_mask = torch .zeros (self .num_experts , dtype = torch .bool , device = "cuda" )
42+ for j in range (self .num_selected ):
43+ # Use a large value for already selected experts to exclude them
44+ load_view = torch .where (selected_mask , torch .iinfo (expert_load .dtype ).max , expert_load )
45+
46+ min_load_indices = torch .where (load_view == load_view .min ())[0 ]
47+
4748 if len (min_load_indices ) > 1 :
4849 # If there are multiple least-loaded experts, select one randomly
49- chosen_index = torch .randint (0 , len (min_load_indices ), (1 ,), device = "cuda" ).item ()
50- chosen_expert_index = min_load_indices [chosen_index ]
50+ rand_idx = torch .randint (0 , len (min_load_indices ), (1 ,), device = "cuda" ).item ()
51+ chosen_expert = min_load_indices [rand_idx ]
5152 else :
52- chosen_expert_index = min_load_indices [0 ]
53- chosen_expert = available_experts [chosen_expert_index ]
54- selected .append (chosen_expert )
55- # Remove the selected expert from the list of available experts
56- available_experts = torch .cat (
57- [available_experts [:chosen_expert_index ], available_experts [chosen_expert_index + 1 :]]
58- )
59- expert_load [chosen_expert ] += 1
53+ chosen_expert = min_load_indices [0 ]
6054
61- tensor [i ] = torch .tensor (selected , dtype = torch .int , device = "cuda" )
55+ tensor [i , j ] = chosen_expert
56+ expert_load [chosen_expert ] += 1
57+ selected_mask [chosen_expert ] = True
6258
6359 return tensor
6460
0 commit comments