@@ -12,7 +12,7 @@ class DpQueue:
1212 def __init__ (self , args , router , base_queue_class , dp_size_in_node ) -> None :
1313 self .dp_size_in_node = dp_size_in_node
1414 self .base_queue_class = base_queue_class
15- self .round_robin_dp_id = 0
15+ self .pre_select_dp_index = self . dp_size_in_node - 1
1616 from lightllm .server .router .manager import RouterManager
1717
1818 self .router : RouterManager = router
@@ -52,8 +52,8 @@ def append(self, req: Req):
5252 suggested_dp_index = req .sample_params .suggested_dp_index
5353 if suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
5454 logger .error (f"input req { req .request_id } dp index { suggested_dp_index } has error" )
55- suggested_dp_index = self .round_robin_dp_id
56- self .round_robin_dp_id = ( self . round_robin_dp_id + 1 ) % self . dp_size_in_node
55+ suggested_dp_index = self ._get_suggest_dp_index ()
56+ self .pre_select_dp_index = suggested_dp_index
5757 req .sample_params .suggested_dp_index = suggested_dp_index
5858 self .inner_queues [suggested_dp_index ].append (req )
5959 else :
@@ -62,13 +62,12 @@ def append(self, req: Req):
6262
6363 def extend (self , req_group : List [Req ]):
6464 # 同一个组的,要分配在同一个 dp 上,效率最高
65- index = self .round_robin_dp_id
66- self .round_robin_dp_id = (self .round_robin_dp_id + 1 ) % self .dp_size_in_node
65+ index = self ._get_suggest_dp_index ()
6766 for req in req_group :
6867 suggested_dp_index = req .sample_params .suggested_dp_index
6968 if suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
7069 logger .error (f"input req { req .request_id } dp index { suggested_dp_index } has error" )
71-
70+ self . pre_select_dp_index = index
7271 req .sample_params .suggested_dp_index = index
7372 self .inner_queues [index ].append (req )
7473 else :
@@ -94,3 +93,21 @@ def update_token_load(self, current_batch: Batch, force_update=False):
9493 self .router .shared_token_load .set_estimated_peak_token_count (estimated_peak_token_count , dp_index )
9594 self .router .shared_token_load .set_dynamic_max_load (dynamic_max_load , dp_index )
9695 return
96+
97+ def _get_suggest_dp_index (self ):
98+ min_length = min (len (queue .waiting_req_list ) for queue in self .inner_queues )
99+ select_dp_indexes = [
100+ i for i , queue in enumerate (self .inner_queues ) if len (queue .waiting_req_list ) == min_length
101+ ]
102+
103+ # multi thread safe keep
104+ if not select_dp_indexes :
105+ return random .randint (0 , self .dp_size_in_node - 1 )
106+
107+ # round_robin select.
108+ for i in range (self .dp_size_in_node ):
109+ next_dp_index = (self .pre_select_dp_index + i + 1 ) % self .dp_size_in_node
110+ if next_dp_index in select_dp_indexes :
111+ return next_dp_index
112+
113+ return random .choice (select_dp_indexes )
0 commit comments