
import threading

+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
def singleton_threadsafe(cls):
    instances = {}
    lock = threading.Lock()
-
    def get_instance(*args, **kwargs):
+        # A key that includes the arguments is needed for parameter-dependent singletons.
+        # Using a tuple of args and a frozenset of kwargs items makes it hashable.
+        key = (cls, args, frozenset(kwargs.items()))
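+        # Note: this relies on every positional and keyword argument being hashable;
+        # an unhashable argument (e.g. a list) would raise a TypeError when the key is built.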
        with lock:
-            if cls not in instances:
-                instances[cls] = cls(*args, **kwargs)
-            return instances[cls]
+            if key not in instances:
+                instances[key] = cls(*args, **kwargs)
+            return instances[key]
    return get_instance
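+# Illustration only (hypothetical class, not part of this change): because the cache key
+# includes the constructor arguments, equal arguments share one instance and different
+# arguments get their own, e.g.
+#
+#   @singleton_threadsafe
+#   class Config:
+#       def __init__(self, n):
+#           self.n = n
+#
+#   assert Config(1) is Config(1)      # same args      -> same cached instance
+#   assert Config(1) is not Config(2)  # different args -> separate instance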

@singleton_threadsafe
@@ -21,43 +27,38 @@ def __init__(self, num_experts=256, num_selected=8):
        self.num_experts = num_experts
        self.num_selected = num_selected

-    def generate_balanced_tensor(self, length):
-        # Initialize a length * 8 all-zero tensor, placed on the GPU
-        tensor = torch.zeros((length, self.num_selected), dtype=torch.int, device='cuda')
-        # Initialize a load counter for every expert
+    def gemini_generate_balanced_tensor(self, length):
+        # A more performant way to generate a balanced tensor for expert selection.
+        tensor = torch.empty((length, self.num_selected), dtype=torch.int, device='cuda')
        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')
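+        # expert_load[e] counts how many times expert e has been selected so far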

+        expert_indices = torch.arange(self.num_experts, device='cuda')
+
        for i in range(length):
-            available_experts = torch.arange(self.num_experts, device='cuda')
-            selected = []
-            for _ in range(self.num_selected):
-                # Compute the current load of each available expert
-                current_load = expert_load[available_experts]
-                # Pick the expert with the smallest load
-                min_load_indices = torch.where(current_load == current_load.min())[0]
-                if len(min_load_indices) > 1:
-                    # If several experts share the minimum load, choose one at random
-                    chosen_index = torch.randint(0, len(min_load_indices), (1,), device='cuda').item()
-                    chosen_expert_index = min_load_indices[chosen_index]
-                else:
-                    chosen_expert_index = min_load_indices[0]
-                chosen_expert = available_experts[chosen_expert_index]
-                selected.append(chosen_expert)
-                # Remove the chosen expert from the list of available experts
-                available_experts = torch.cat(
-                    [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1:]])
-                # Update that expert's load
-                expert_load[chosen_expert] += 1
-            tensor[i] = torch.tensor(selected, dtype=torch.int, device='cuda')
+            # To break ties randomly when loads are equal, we can shuffle indices
+            # of experts with the same load. A simple way is to shuffle all
+            # indices and then sort by load.
+            shuffled_indices = expert_indices[torch.randperm(self.num_experts, device='cuda')]
+            sorted_shuffled_indices = shuffled_indices[torch.argsort(expert_load[shuffled_indices])]
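+            # Note: torch.argsort is not guaranteed to be stable here, so ties between
+            # equal loads may also be broken by the sort itself; the up-front shuffle
+            # mainly guards against a deterministic bias toward low expert ids.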
+
+            # Select the top `num_selected` experts with the lowest load
+            selected_experts = sorted_shuffled_indices[:self.num_selected]
+
+            tensor[i] = selected_experts
+
+            # Update loads for the selected experts using an efficient scatter_add
+            expert_load.scatter_add_(0, selected_experts, torch.ones_like(selected_experts, dtype=torch.int))
+
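+        # At this point each row of `tensor` holds self.num_selected distinct expert ids,
+        # and because the least-loaded experts are always picked first, the per-expert
+        # counts in expert_load should never differ by more than one.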
        return tensor

-    def get_balance_topk_ids(self, length):
-        if self.balanced_tensors.get(length) is not None:
-            #print("find length ", length)
-            return self.balanced_tensors[length]
+
+    def get_balance_topk_ids(self, num_tokens):
+        if self.balanced_tensors.get(num_tokens) is not None:
+            # logger.info(f"find balanced tensor for num_tokens={num_tokens}")
+            return self.balanced_tensors[num_tokens]
        else:
-            #print( "generate length ", length )
-            tensor = self.generate_balanced_tensor(length)
-            self.balanced_tensors[length] = tensor
+            # logger.info(f"generate balanced tensor for num_tokens={num_tokens}")
+            tensor = self.gemini_generate_balanced_tensor(num_tokens)
+            self.balanced_tensors[num_tokens] = tensor
            return tensor