@@ -55,11 +55,11 @@ def __init__(
5555 return
5656
5757 async def register_pd (self , pd_info_json , websocket ):
58- await self .pd_manager .register_pd (pd_info_json , websocket )
58+ self .pd_manager .register_pd (pd_info_json , websocket )
5959 return
6060
6161 async def remove_pd (self , pd_info_json ):
62- await self .pd_manager .remove_pd (pd_info_json )
62+ self .pd_manager .remove_pd (pd_info_json )
6363 return
6464
6565 async def update_req_status (self , upkv_status : UpKVStatus ):
@@ -92,7 +92,7 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
9292 async def select_p_d_node (
9393 self , prompt : Union [str , List [int ]], sampling_params : SamplingParams , multimodal_params : MultimodalParams
9494 ) -> Tuple [PD_Client_Obj , PD_Client_Obj ]:
95- return await self .pd_manager .select_p_d_node (prompt , sampling_params , multimodal_params )
95+ return self .pd_manager .select_p_d_node (prompt , sampling_params , multimodal_params )
9696
9797 async def generate (
9898 self ,
@@ -411,7 +411,7 @@ def __init__(self, args):
411411 )
412412 return
413413
414- async def register_pd (self , pd_info_json , websocket ):
414+ def register_pd (self , pd_info_json , websocket ):
415415 pd_client = PD_Client_Obj (** pd_info_json )
416416 pd_client .websocket = websocket
417417 self .url_to_pd_nodes [pd_client .client_ip_port ] = pd_client
@@ -425,19 +425,19 @@ async def register_pd(self, pd_info_json, websocket):
425425 else :
426426 assert False , f"mode must in ['prefill', 'decode'], but get { pd_client .mode } "
427427
428- await self .selector .update_nodes (self .prefill_nodes , self .decode_nodes )
428+ self .selector .update_nodes (self .prefill_nodes , self .decode_nodes )
429429
430430 logger .info (f"mode: { pd_client .mode } url: { pd_client .client_ip_port } registed" )
431431 return
432432
433- async def remove_pd (self , pd_info_json ):
433+ def remove_pd (self , pd_info_json ):
434434 pd_client = PD_Client_Obj (** pd_info_json )
435435
436436 self .url_to_pd_nodes .pop (pd_client .client_ip_port , None )
437437 self .prefill_nodes = [e for e in self .prefill_nodes if e .client_ip_port != pd_client .client_ip_port ]
438438 self .decode_nodes = [e for e in self .decode_nodes if e .client_ip_port != pd_client .client_ip_port ]
439439
440- await self .selector .update_nodes (self .prefill_nodes , self .decode_nodes )
440+ self .selector .update_nodes (self .prefill_nodes , self .decode_nodes )
441441
442442 logger .info (f"mode: { pd_client .mode } url: { pd_client .client_ip_port } removed" )
443443 return
@@ -461,13 +461,8 @@ def update_node_load_info(self, load_info: Optional[dict]):
461461 logger .warning (f"udpate node load info failed, load_info: { load_info } error: { str (e )} " )
462462 return
463463
464- def get_predict_node_infos (self ):
465- """获取所有节点的预测负载信息"""
466- return self .node_info_recorder .get_predict_node_infos ()
467-
468- async def select_p_d_node (
464+ def select_p_d_node (
469465 self , prompt : Union [str , List [int ]], sampling_params : SamplingParams , multimodal_params : MultimodalParams
470466 ) -> Tuple [PD_Client_Obj , PD_Client_Obj ]:
471- p_node , d_node = await self .selector .select_p_d_node (prompt , sampling_params , multimodal_params )
472- self .node_info_recorder .update_predict_node_info (p_node , d_node , prompt , sampling_params , multimodal_params )
467+ p_node , d_node = self .selector .select_p_d_node (prompt , sampling_params , multimodal_params )
473468 return p_node , d_node
0 commit comments