 
 import threading
 
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
 def singleton_threadsafe(cls):
     instances = {}
     lock = threading.Lock()
 
     def get_instance(*args, **kwargs):
+        # A key that includes the arguments is needed for parameter-dependent singletons.
+        # Using a tuple of args and a frozenset of kwargs items makes it hashable.
+        key = (cls, args, frozenset(kwargs.items()))
         with lock:
-            if cls not in instances:
-                instances[cls] = cls(*args, **kwargs)
-            return instances[cls]
+            if key not in instances:
+                instances[key] = cls(*args, **kwargs)
+            return instances[key]
+
     return get_instance
 
+
 @singleton_threadsafe
 class BalancedTensor:
     def __init__(self, num_experts=256, num_selected=8):
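The key change in the decorator above is that the instance cache is now keyed on (cls, args, frozenset(kwargs.items())) rather than on the class alone, so calls with different constructor arguments yield distinct singletons while identical calls share one instance. A minimal sketch of the resulting behavior (the Config class below is a hypothetical example, not part of this patch):

@singleton_threadsafe
class Config:  # hypothetical example class, not part of this patch
    def __init__(self, num_experts=256):
        self.num_experts = num_experts

a = Config(num_experts=256)
b = Config(num_experts=256)
c = Config(num_experts=64)
assert a is b      # same kwargs -> the cached instance is reused
assert a is not c  # different kwargs -> a separate instance is created

One caveat of this keying scheme: Config(256) and Config(num_experts=256) produce different keys (positional vs. keyword arguments), so the same logical configuration can still yield two instances.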
@@ -22,42 +32,34 @@ def __init__(self, num_experts=256, num_selected=8):
         self.num_selected = num_selected
 
     def generate_balanced_tensor(self, length):
-        # Initialize a length * 8 all-zeros tensor, placed on the GPU
-        tensor = torch.zeros((length, self.num_selected), dtype=torch.int, device='cuda')
-        # Initialize a load counter for each expert
-        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')
+        tensor = torch.empty((length, self.num_selected), dtype=torch.int, device="cuda")
+        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device="cuda")
+
+        expert_indices = torch.arange(self.num_experts, device="cuda")
 
         for i in range(length):
-            available_experts = torch.arange(self.num_experts, device='cuda')
-            selected = []
-            for _ in range(self.num_selected):
-                # Compute the current load of each available expert
-                current_load = expert_load[available_experts]
-                # Pick the expert with the smallest load
-                min_load_indices = torch.where(current_load == current_load.min())[0]
-                if len(min_load_indices) > 1:
-                    # If several experts share the minimum load, pick one at random
-                    chosen_index = torch.randint(0, len(min_load_indices), (1,), device='cuda').item()
-                    chosen_expert_index = min_load_indices[chosen_index]
-                else:
-                    chosen_expert_index = min_load_indices[0]
-                chosen_expert = available_experts[chosen_expert_index]
-                selected.append(chosen_expert)
-                # Remove the chosen expert from the list of available experts
-                available_experts = torch.cat(
-                    [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1:]])
-                # Update that expert's load
-                expert_load[chosen_expert] += 1
-            tensor[i] = torch.tensor(selected, dtype=torch.int, device='cuda')
+            # To break ties randomly when loads are equal, we can shuffle indices
+            # of experts with the same load. A simple way is to shuffle all
+            # indices and then sort by load.
+            shuffled_indices = expert_indices[torch.randperm(self.num_experts, device="cuda")]
+            sorted_shuffled_indices = shuffled_indices[torch.argsort(expert_load[shuffled_indices])]
+
+            # Select the top `num_selected` experts with the lowest load
+            selected_experts = sorted_shuffled_indices[: self.num_selected]
+
+            tensor[i] = selected_experts
+
+            # Update loads for the selected experts using an efficient scatter_add
+            expert_load.scatter_add_(0, selected_experts, torch.ones_like(selected_experts, dtype=torch.int))
+
         return tensor
 
-    def get_balance_topk_ids(self, length):
-        if self.balanced_tensors.get(length) is not None:
-            #print("find length ", length)
-            return self.balanced_tensors[length]
+    def get_balance_topk_ids(self, num_tokens):
+        if self.balanced_tensors.get(num_tokens) is not None:
+            # logger.info(f"find balanced tensor for num_tokens={num_tokens}")
+            return self.balanced_tensors[num_tokens]
         else:
-            #print("generate length ", length)
-            tensor = self.generate_balanced_tensor(length)
-            self.balanced_tensors[length] = tensor
+            # logger.info(f"generate balanced tensor for num_tokens={num_tokens}")
+            tensor = self.generate_balanced_tensor(num_tokens)
+            self.balanced_tensors[num_tokens] = tensor
             return tensor
-
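For reference, a minimal usage sketch of the class after this change (assuming a CUDA device is available; torch.int is int32):

bt = BalancedTensor(num_experts=256, num_selected=8)
ids = bt.get_balance_topk_ids(num_tokens=1024)
print(ids.shape, ids.dtype)  # torch.Size([1024, 8]) torch.int32

# A second call with the same num_tokens hits the cache and
# returns the very same tensor object.
assert bt.get_balance_topk_ids(1024) is ids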
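Because each row greedily takes the num_selected least-loaded experts, per-expert assignment counts never drift apart by more than one. A quick property check along these lines (again assuming CUDA; torch.bincount tallies how often each expert id appears):

import torch

bt = BalancedTensor()
ids = bt.get_balance_topk_ids(1000)
load = torch.bincount(ids.flatten().long(), minlength=bt.num_experts)

assert load.sum().item() == 1000 * bt.num_selected  # every slot assigned
assert (load.max() - load.min()).item() <= 1        # loads stay within 1 of each other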