1515
1616logger = init_logger (__name__ )
1717
18+
1819async def timer_log (manager : HttpServerManager ):
1920 while True :
2021 await asyncio .sleep (30 )
@@ -32,13 +33,13 @@ async def pd_handle_loop(manager: HttpServerManager):
3233
3334 asyncio .create_task (timer_log (manager ))
3435
35- id_to_handle_task :Dict [int , asyncio .Task ] = {}
36+ id_to_handle_task : Dict [int , asyncio .Task ] = {}
3637
3738 while True :
3839 try :
3940 id_to_pd_master_obj = await _get_pd_master_objs (manager .args )
4041 logger .info (f"get pd_master_objs { id_to_pd_master_obj } " )
41-
42+
4243 if id_to_pd_master_obj is not None :
4344 for node_id , pd_master_obj in id_to_handle_task .items ():
4445 if node_id not in id_to_pd_master_obj :
@@ -51,7 +52,7 @@ async def pd_handle_loop(manager: HttpServerManager):
5152 id_to_handle_task [node_id ] = asyncio .create_task (_pd_handle_task (manager , pd_master_obj ))
5253
5354 await asyncio .sleep (30 )
54-
55+
5556 except Exception as e :
5657 logger .exception (str (e ))
5758 await asyncio .sleep (10 )
@@ -70,7 +71,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
7071 try :
7172 uri = f"ws://{ pd_master_obj .host_ip_port } /pd_register"
7273 async with websockets .connect (uri , max_queue = (2048 * 1024 , 2048 * 1023 )) as websocket :
73-
74+
7475 sock = websocket .transport .get_extra_info ("socket" )
7576 sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
7677
@@ -83,35 +84,35 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
8384 "mode" : manager .pd_mode .value ,
8485 "start_args" : args_dict ,
8586 }
86-
87+
8788 await websocket .send (json .dumps (regist_json ))
8889 logger .info (f"Sent registration JSON: { regist_json } " )
89-
90+
9091 # 转发任务
91- forwarding_tokens_task = asyncio .create_task (
92- _up_tokens_to_pd_master (forwarding_queue , websocket )
93- )
94-
92+ forwarding_tokens_task = asyncio .create_task (_up_tokens_to_pd_master (forwarding_queue , websocket ))
93+
9594 # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。
9695 while True :
9796 recv_bytes = await websocket .recv ()
9897 obj = pickle .loads (recv_bytes )
9998 if obj [0 ] == ObjType .REQ :
10099 prompt , sampling_params , multimodal_params = obj [1 ]
101- asyncio .create_task (_pd_process_generate (manager , prompt , sampling_params , multimodal_params , forwarding_queue ))
100+ asyncio .create_task (
101+ _pd_process_generate (manager , prompt , sampling_params , multimodal_params , forwarding_queue )
102+ )
102103 elif obj [0 ] == ObjType .ABORT :
103104 group_req_id = obj [1 ]
104105 await manager .abort (group_req_id )
105106 else :
106107 logger .error (f"recevie error obj { str (obj )} " )
107-
108+
108109 except asyncio .CancelledError :
109110 # 如果任务被取消,则退出循环
110111 logger .warning (f"forwarding_tokens_task { pd_master_obj } cancelled" )
111112 if forwarding_tokens_task is not None :
112113 forwarding_tokens_task .cancel ()
113114 return
114-
115+
115116 except Exception as e :
116117 logger .error ("connetion to pd_master has error" )
117118 logger .exception (str (e ))
@@ -122,7 +123,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
122123 logger .info ("reconnection to pd_master" )
123124
124125
125- async def _get_pd_master_objs (args )-> Optional [Dict [int , PD_Master_Obj ]]:
126+ async def _get_pd_master_objs (args ) -> Optional [Dict [int , PD_Master_Obj ]]:
126127 """
127128 get_pd_master_objs 主要负责从 pd master 获取所有的pd master对象。
128129 """
@@ -135,15 +136,15 @@ async def _get_pd_master_objs(args)->Optional[Dict[int, PD_Master_Obj]]:
135136 ans = dict ()
136137 ans [0 ] = PD_Master_Obj (node_id = 0 , host_ip_port = f"{ args .pd_master_ip } :{ args .pd_master_port } " )
137138 return ans
138-
139+
139140 # 使用 config_server 服务来发现所有的 pd_master 节点。
140141 uri = f"ws://{ args .config_server_host } :{ args .config_server_port } /registered_objects"
141142
142143 try :
143144 async with httpx .AsyncClient () as client :
144145 response = await client .get (uri )
145146 if response .status_code == 200 :
146- base64data = response .json ()["data" ]
147+ base64data = response .json ()["data" ]
147148 id_to_pd_master_obj = pickle .loads (base64 .b64decode (base64data ))
148149 return id_to_pd_master_obj
149150 else :
@@ -154,8 +155,11 @@ async def _get_pd_master_objs(args)->Optional[Dict[int, PD_Master_Obj]]:
154155 await asyncio .sleep (10 )
155156 return None
156157
158+
157159# 触发推理的task
158- async def _pd_process_generate (manager : HttpServerManager , prompt , sampling_params , multimodal_params , forwarding_queue :AsyncQueue ):
160+ async def _pd_process_generate (
161+ manager : HttpServerManager , prompt , sampling_params , multimodal_params , forwarding_queue : AsyncQueue
162+ ):
159163 try :
160164 async for sub_req_id , request_output , metadata , finish_status in manager .generate (
161165 prompt , sampling_params , multimodal_params , None
@@ -175,4 +179,3 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket):
175179 handle_list = await forwarding_queue .wait_to_get_all_data ()
176180 if handle_list :
177181 await websocket .send (pickle .dumps ((ObjType .TOKEN_PACKS , handle_list )))
178-
0 commit comments