22from typing import List
33from ..batch import Batch , Req
44from lightllm .server .router .req_queue .base_queue import BaseQueue
5+ from lightllm .server .router .req_queue .dp_balancer import get_dp_balancer
56from lightllm .common .basemodel .infer_lock import g_router_lock
67from lightllm .utils .log_utils import init_logger
78
@@ -12,14 +13,13 @@ class DpQueue:
def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
    """Create one inner queue per dp index plus the balancer that picks among them.

    Args:
        args: Passed through unchanged to every inner queue and to
            ``get_dp_balancer``.
        router: The owning router; stored on ``self.router`` and handed to
            every inner queue.
        base_queue_class: Callable invoked as
            ``base_queue_class(args, router, dp_index, dp_size_in_node)`` to
            build each inner queue.
        dp_size_in_node: Number of inner queues to create; valid dp indexes
            are ``0 .. dp_size_in_node - 1``.
    """
    self.dp_size_in_node = dp_size_in_node
    self.base_queue_class = base_queue_class
    # Function-scope import, used only for the annotation below.
    # NOTE(review): presumably kept local to avoid a circular import with
    # the router manager module -- confirm before hoisting it.
    from lightllm.server.router.manager import RouterManager

    self.router: RouterManager = router
    self.inner_queues: List[BaseQueue] = [
        base_queue_class(args, router, dp_index, dp_size_in_node)
        for dp_index in range(self.dp_size_in_node)
    ]
    # The balancer receives the same list object as ``self.inner_queues``;
    # it is the source of dp indexes for requests without a valid one.
    self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues)
    return
2424
2525 def get_dp_queue (self , dp_index : int ):
@@ -49,8 +49,7 @@ def append(self, req: Req):
4949 suggested_dp_index = req .sample_params .suggested_dp_index
5050 if suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
5151 logger .warning (f"input req { req .request_id } dp index { suggested_dp_index } is invalid" )
52- suggested_dp_index = self ._get_suggest_dp_index ()
53- self .pre_select_dp_index = suggested_dp_index
52+ suggested_dp_index = self .dp_balancer .get_suggest_dp_index ()
5453 req .sample_params .suggested_dp_index = suggested_dp_index
5554 self .inner_queues [suggested_dp_index ].append (req )
5655 else :
@@ -59,12 +58,11 @@ def append(self, req: Req):
5958
6059 def extend (self , req_group : List [Req ]):
6160 # 同一个组的,要分配在同一个 dp 上,效率最高
62- index = self ._get_suggest_dp_index ()
61+ index = self .dp_balancer . get_suggest_dp_index ()
6362 for req in req_group :
6463 suggested_dp_index = req .sample_params .suggested_dp_index
6564 if suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
6665 logger .warning (f"input req { req .request_id } dp index { suggested_dp_index } is invalid" )
67- self .pre_select_dp_index = index
6866 req .sample_params .suggested_dp_index = index
6967 self .inner_queues [index ].append (req )
7068 else :
@@ -87,21 +85,3 @@ def update_token_load(self, current_batch: Batch, force_update=False):
8785 self .router .shared_token_load .set_estimated_peak_token_count (estimated_peak_token_count , dp_index )
8886 self .router .shared_token_load .set_dynamic_max_load (dynamic_max_load , dp_index )
8987 return
90-
def _get_suggest_dp_index(self):
    """Pick the dp index of a least-loaded inner queue, rotating among ties.

    Load is ``len(queue.waiting_req_list)`` for each queue in
    ``self.inner_queues``. Among the queues sharing the minimum load, the
    first one found when scanning circularly from the index after
    ``self.pre_select_dp_index`` is returned. If that attribute is not set,
    scanning starts as though the last pick were ``dp_size_in_node - 1``,
    i.e. from index 0.

    This method does not update ``self.pre_select_dp_index``; the caller
    must record the pick if it wants the rotation to advance.

    Returns:
        int: The selected dp index.

    Raises:
        ValueError: If ``self.inner_queues`` is empty (``min`` of an empty
            list).
    """
    # Function-scope import: only needed for the defensive fallback below.
    import random

    # Snapshot every length once so the minimum and the candidate filter
    # are computed from the same data even if another thread is appending
    # to a queue while we look.
    queue_lengths = [len(queue.waiting_req_list) for queue in self.inner_queues]
    shortest = min(queue_lengths)
    candidates = [i for i, length in enumerate(queue_lengths) if length == shortest]

    last_pick = getattr(self, "pre_select_dp_index", self.dp_size_in_node - 1)

    # Round-robin: walk every index once, starting just after the last pick.
    for offset in range(1, self.dp_size_in_node + 1):
        dp_index = (last_pick + offset) % self.dp_size_in_node
        if dp_index in candidates:
            return dp_index

    # Only reachable if dp_size_in_node does not cover every candidate
    # index (e.g. it disagrees with len(self.inner_queues)).
    return random.choice(candidates)
0 commit comments