@@ -50,11 +50,12 @@ def __init__(
         context = zmq.asyncio.Context(2)
         self.send_to_router = context.socket(zmq.PUSH)
         self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}")
-
+
         self.multinode_req_manager = None
         self.child_node_events = {}
         self.waiting_objs = []
         self.child_node_lock = asyncio.Lock()
+        self.nnodes = args.nnodes
         if args.nnodes > 1:
             if args.node_rank == 0:
                 self.multinode_req_manager = []
@@ -64,12 +65,16 @@ def __init__(
                     context = zmq.asyncio.Context(2)
                     self.multinode_req_manager.append(context.socket(zmq.PUSH))
                     self.multinode_req_manager[-1].connect(f"tcp://{child_ip}:{args.multinode_httpmanager_port}")
-                    logger.info(f"HttpServerManager connected to child node at {child_ip}:{args.multinode_httpmanager_port}")
+                    logger.info(
+                        f"HttpServerManager connected to child node at {child_ip}:{args.multinode_httpmanager_port}"
+                    )
             else:
                 context = zmq.asyncio.Context(2)
                 self.multinode_req_manager = context.socket(zmq.PULL)
                 self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}")
-                logger.info(f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}")
+                logger.info(
+                    f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}"
+                )
 
         self.enable_multimodal = enable_multimodal
         if self.enable_multimodal:
@@ -145,16 +150,26 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam
     def tokens(self, prompt):
         prompt_ids = self.tokenizer.encode(prompt)
         return len(prompt_ids)
-
+
     async def loop_for_request(self):
         assert self.args.node_rank > 0
         tasks = []
         while True:
-            request_id, prompt, sampling_params, multimodal_params, request_headers = await self.multinode_req_manager.recv_pyobj()
-            results_generator = self.generate(prompt, sampling_params, multimodal_params, None, request_headers, request_id)
+            (
+                request_id,
+                prompt,
+                sampling_params,
+                multimodal_params,
+                request_headers,
+            ) = await self.multinode_req_manager.recv_pyobj()
+            results_generator = self.generate(
+                prompt, sampling_params, multimodal_params, None, request_headers, request_id
+            )
+
             async def generate_wrapper(results_generator):
                 async for _, _, _, _ in results_generator:
                     pass
+
             tasks.append(asyncio.create_task(generate_wrapper(results_generator)))
             # cleanup
             while len(tasks) > 0 and tasks[0].done():
@@ -166,7 +181,7 @@ async def generate(
         sampling_params: SamplingParams,
         multimodal_params: MultimodalParams,
         request: Request,
-        request_headers=None,
+        request_headers=None,
         multinode_remote_request_id: Optional[int] = None,
     ) -> Tuple[int, str, dict, FinishStatus]:
         start_time = time.time()
@@ -178,7 +193,10 @@ async def generate(
         if multinode_remote_request_id is None:
             group_request_id = self.id_gen.generate_id()
             for sender in self.multinode_req_manager:
-                sender.send_pyobj((group_request_id, prompt, sampling_params, multimodal_params, request_headers), protocol=pickle.HIGHEST_PROTOCOL)
+                sender.send_pyobj(
+                    (group_request_id, prompt, sampling_params, multimodal_params, request_headers),
+                    protocol=pickle.HIGHEST_PROTOCOL,
+                )
         else:
             group_request_id = multinode_remote_request_id
         sampling_params.group_request_id = group_request_id
@@ -238,8 +256,12 @@ async def generate(
             await self.transfer_to_next_module(req_status.group_req_objs)
 
         results_generator = self._wait_to_token_package(
-            start_time, prompt_ids, group_request_id, sampling_params, req_status,  # request,
-            request_headers,
+            start_time,
+            prompt_ids,
+            group_request_id,
+            sampling_params,
+            req_status,
+            request,
         )
         async for sub_req_id, request_output, metadata, finish_status in results_generator:
             # p d 模式下,将 token 数据放入到转发队列中
@@ -368,8 +390,7 @@ async def _wait_to_token_package(
         group_request_id: int,
         sampling_params: SamplingParams,
         req_status: "ReqStatus",
-        request_headers,
-        # request: Request,
+        request: Request,
     ):
 
         event = req_status.event
@@ -385,10 +406,9 @@ async def _wait_to_token_package(
             except asyncio.TimeoutError:
                 pass
 
-            # TODO: abort() for multinode
-            # if request is not None and await request.is_disconnected():
-            #     await self.abort(group_request_id)
-            #     raise Exception(f"req_id {group_request_id} disconnected")
+            if request is not None and await request.is_disconnected() and self.nnodes == 1:
+                await self.abort(group_request_id)
+                raise Exception(f"req_id {group_request_id} disconnected")
 
             async with req_status.lock:
                 event.clear()
@@ -416,13 +436,12 @@ async def _wait_to_token_package(
                     unfinished_count -= 1
 
                 # 所有子请求完成后,就删除占用的资源
-                if unfinished_count == 0:
+                if unfinished_count == 0 and request is not None:
                     total_cost_time_ms = (time.time() - start_time) * 1000
                     mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter
                     self.per_token_costs.add(mean_per_token_cost_time_ms)
-                    x_request_id = request_headers.get("X-Request-Id", "") if request_headers is not None else ""
-                    x_session_id = request_headers.get("X-Session-Id", "") if request_headers is not None else ""
-
+                    x_request_id = request.headers.get("X-Request-Id", "") if request is not None else ""
+                    x_session_id = request.headers.get("X-Session-Id", "") if request is not None else ""
                     prompt_cache_ratio = prompt_cache_len / prompt_tokens
                     self.metric_client.histogram_observe("lightllm_cache_length", prompt_cache_len)
                     self.metric_client.histogram_observe("lightllm_cache_ratio", prompt_cache_ratio)
@@ -506,7 +525,7 @@ async def handle_loop(self):
         if self.pd_mode.is_P_or_D():
             self.forwarding_queue = AsyncQueue()
             asyncio.create_task(self.pd_handle_loop())
-
+
         if self.args.node_rank > 0:
             asyncio.create_task(self.loop_for_request())
 