1212import pickle
1313
1414asyncio .set_event_loop_policy (uvloop .EventLoopPolicy ())
15- from typing import Union , List , Tuple , Dict
15+ from typing import Union , List , Tuple , Dict , Optional
1616from lightllm .server .core .objs import FinishStatus
1717from ..pd_io_struct import PD_Client_Obj , UpKVStatus , ObjType
1818from lightllm .server .core .objs import SamplingParams
2525from lightllm .utils .statics_utils import MovingAverage
2626from lightllm .server .httpserver .manager import AsyncQueue
2727from lightllm .utils .error_utils import ServerBusyError
28+ from .pd_selector import create_selector
2829
2930logger = init_logger (__name__ )
3031
@@ -38,9 +39,8 @@ def __init__(
3839 self .args = args
3940 self .metric_client = MetricClient (metric_port )
4041 self .id_gen = ReqIDGenerator ()
41- self .prefill_nodes : List [PD_Client_Obj ] = []
42- self .decode_nodes : List [PD_Client_Obj ] = []
43- self .url_to_pd_nodes : Dict [str , PD_Client_Obj ] = {}
42+
43+ self .pd_manager = PDManager (args )
4444
4545 self .req_id_to_out_inf : Dict [int , ReqStatus ] = {}
4646 self .infos_queues = None # 这个需要延迟初始化,否则使用的loop不对
@@ -52,30 +52,11 @@ def __init__(
5252 return
5353
5454 async def register_pd (self , pd_info_json , websocket ):
55- pd_client = PD_Client_Obj (** pd_info_json )
56- pd_client .websocket = websocket
57- self .url_to_pd_nodes [pd_client .client_ip_port ] = pd_client
58- if pd_client .mode == "prefill" :
59- self .prefill_nodes = [e for e in self .prefill_nodes if e .client_ip_port != pd_client .client_ip_port ]
60- self .prefill_nodes .append (pd_client )
61- elif pd_client .mode == "decode" :
62- self .decode_nodes = [e for e in self .decode_nodes if e .client_ip_port != pd_client .client_ip_port ]
63- self .decode_nodes .append (pd_client )
64- else :
65- assert False
66-
67- logger .info (f"mode: { pd_client .mode } url: { pd_client .client_ip_port } registed" )
55+ self .pd_manager .register_pd (pd_info_json , websocket )
6856 return
6957
7058 async def remove_pd (self , pd_info_json ):
71- pd_client = PD_Client_Obj (** pd_info_json )
72- try :
73- del self .url_to_pd_nodes [pd_client .client_ip_port ]
74- except :
75- pass
76- self .prefill_nodes = [e for e in self .prefill_nodes if e .client_ip_port != pd_client .client_ip_port ]
77- self .decode_nodes = [e for e in self .decode_nodes if e .client_ip_port != pd_client .client_ip_port ]
78- logger .info (f"mode: { pd_client .mode } url: { pd_client .client_ip_port } removed" )
59+ self .pd_manager .remove_pd (pd_info_json )
7960 return
8061
8162 async def update_req_status (self , upkv_status : UpKVStatus ):
@@ -108,11 +89,7 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
10889 async def select_p_d_node (
10990 self , prompt : Union [str , List [int ]], sampling_params : SamplingParams , multimodal_params : MultimodalParams
11091 ) -> Tuple [PD_Client_Obj , PD_Client_Obj ]:
111- import random
112-
113- p_node = random .choice (self .prefill_nodes )
114- d_node = random .choice (self .decode_nodes )
115- return p_node , d_node
92+ return self .pd_manager .select_p_d_node (prompt , sampling_params , multimodal_params )
11693
11794 async def generate (
11895 self ,
@@ -264,7 +241,7 @@ async def _wait_to_token_package(
264241 request : Request ,
265242 ):
266243 out_token_counter = 0
267- first_token_cost_ms = sys . float_info . max
244+ first_token_cost_ms = float ( "inf" )
268245 group_request_id = sampling_params .group_request_id
269246 unfinished_count = sampling_params .best_of
270247 is_first_token = True
@@ -368,7 +345,10 @@ async def handle_loop(self):
368345 try :
369346 for obj in objs :
370347 if obj [0 ] == ObjType .TOKEN_PACKS :
371- for sub_req_id , text , metadata , finish_status in obj [1 ]:
348+ token_list , node_load_info = obj [1 ], obj [2 ]
349+ self .pd_manager .update_node_load_info (node_load_info )
350+
351+ for sub_req_id , text , metadata , finish_status in token_list :
372352 finish_status : FinishStatus = finish_status
373353 group_req_id = convert_sub_id_to_group_id (sub_req_id )
374354 try :
@@ -415,3 +395,69 @@ async def pop_all_tokens(self):
415395 ans = self .out_token_info_list .copy ()
416396 self .out_token_info_list .clear ()
417397 return ans
398+
399+
400+ class PDManager :
401+ def __init__ (self , args ):
402+ self .args = args
403+ self .prefill_nodes : List [PD_Client_Obj ] = []
404+ self .decode_nodes : List [PD_Client_Obj ] = []
405+ self .url_to_pd_nodes : Dict [str , PD_Client_Obj ] = {}
406+ self .selector = create_selector (args .select_p_d_node_strategy , self )
407+ return
408+
409+ def register_pd (self , pd_info_json , websocket ):
410+ pd_client = PD_Client_Obj (** pd_info_json )
411+ pd_client .websocket = websocket
412+ self .url_to_pd_nodes [pd_client .client_ip_port ] = pd_client
413+
414+ if pd_client .mode == "prefill" :
415+ self .prefill_nodes = [e for e in self .prefill_nodes if e .client_ip_port != pd_client .client_ip_port ]
416+ self .prefill_nodes .append (pd_client )
417+ elif pd_client .mode == "decode" :
418+ self .decode_nodes = [e for e in self .decode_nodes if e .client_ip_port != pd_client .client_ip_port ]
419+ self .decode_nodes .append (pd_client )
420+ else :
421+ assert False , f"mode must in ['prefill', 'decode'], but get { pd_client .mode } "
422+
423+ self .selector .update_nodes (self .prefill_nodes , self .decode_nodes )
424+
425+ logger .info (f"mode: { pd_client .mode } url: { pd_client .client_ip_port } registed" )
426+ return
427+
428+ def remove_pd (self , pd_info_json ):
429+ pd_client = PD_Client_Obj (** pd_info_json )
430+
431+ self .url_to_pd_nodes .pop (pd_client .client_ip_port , None )
432+ self .prefill_nodes = [e for e in self .prefill_nodes if e .client_ip_port != pd_client .client_ip_port ]
433+ self .decode_nodes = [e for e in self .decode_nodes if e .client_ip_port != pd_client .client_ip_port ]
434+
435+ self .selector .update_nodes (self .prefill_nodes , self .decode_nodes )
436+
437+ logger .info (f"mode: { pd_client .mode } url: { pd_client .client_ip_port } removed" )
438+ return
439+
440+ def update_node_load_info (self , load_info : Optional [dict ]):
441+ """更新节点负载信息
442+ load_info: 节点负载信息字典,内容格式如下,可以为 None
443+ {
444+ "total_token_usage_rate": xxxx,
445+ "client_ip_port": xxxx,
446+ }
447+ """
448+ try :
449+ if load_info is None :
450+ return
451+ client_ip_port = load_info ["client_ip_port" ]
452+ total_token_usage_rate = load_info ["total_token_usage_rate" ]
453+ pd_client = self .url_to_pd_nodes .get (client_ip_port )
454+ pd_client .run_status .total_token_usage_rate = total_token_usage_rate
455+ except BaseException as e :
456+ logger .warning (f"udpate node load info failed, load_info: { load_info } error: { str (e )} " )
457+ return
458+
459+ def select_p_d_node (
460+ self , prompt : Union [str , List [int ]], sampling_params : SamplingParams , multimodal_params : MultimodalParams
461+ ) -> Tuple [PD_Client_Obj , PD_Client_Obj ]:
462+ p_node , d_node = self .selector .select_p_d_node (prompt , sampling_params , multimodal_params )
463+ return p_node , d_node
0 commit comments