2525from lightllm .server .core .objs .io_objs import GroupReqObjs
2626from fastapi import Request
2727from lightllm .server .core .objs .shm_req_manager import ShmReqManager
28- from lightllm .server .core .objs .ordered_req_manager import OrderedRequestManager
2928from lightllm .utils .log_utils import init_logger
3029from lightllm .utils .envs_utils import get_env_start_args
3130from lightllm .server .metrics .manager import MetricClient
@@ -57,7 +56,8 @@ def __init__(
5756 self .waiting_objs = []
5857 self .child_node_lock = asyncio .Lock ()
5958 self .nnodes = args .nnodes
60- self .order_req_manager = OrderedRequestManager ()
59+ self .node_rank = args .node_rank
60+ self .transfer_lock = asyncio .Lock ()
6161 if args .nnodes > 1 :
6262 if args .node_rank == 0 :
6363 self .multinode_req_manager = []
@@ -155,12 +155,14 @@ def tokens(self, prompt, kwargs=None):
155155 async def loop_for_request (self ):
156156 assert self .args .node_rank > 0
157157 tasks = []
158+ self .request_order_queue = []
158159 while True :
159160 (
160161 prompt ,
161162 sampling_params ,
162163 multimodal_params ,
163164 ) = await self .multinode_req_manager .recv_pyobj ()
165+ self .request_order_queue .append (sampling_params .group_request_id )
164166 results_generator = self .generate (prompt , sampling_params , multimodal_params , None )
165167
166168 async def generate_wrapper (results_generator ):
@@ -186,12 +188,6 @@ async def generate(
186188 if self .pd_mode == NodeRole .NORMAL :
187189 if sampling_params .group_request_id == - 1 :
188190 group_request_id = self .id_gen .generate_id ()
189- for sender in self .multinode_req_manager :
190- sampling_params .group_request_id = group_request_id
191- sender .send_pyobj (
192- (prompt , sampling_params , multimodal_params ),
193- protocol = pickle .HIGHEST_PROTOCOL ,
194- )
195191 else :
196192 group_request_id = sampling_params .group_request_id
197193 sampling_params .group_request_id = group_request_id
@@ -247,10 +243,9 @@ async def generate(
247243 req_status = ReqStatus (group_request_id , multimodal_params , req_objs , start_time )
248244 self .req_id_to_out_inf [group_request_id ] = req_status
249245
250- # 将请求转发给其他节点
251- await self .order_req_manager .add_request (req_status .group_req_objs )
252- async with self .order_req_manager .lock :
253- await self .transfer_to_next_module ()
246+ await self .transfer_to_next_module_or_node (
247+ prompt , sampling_params , multimodal_params , req_status .group_req_objs
248+ )
254249
255250 results_generator = self ._wait_to_token_package (
256251 start_time ,
@@ -339,10 +334,43 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
339334
340335 return prompt_ids
341336
337+ async def transfer_to_next_module_or_node (
338+ self ,
339+ prompt : str ,
340+ sampling_params : SamplingParams ,
341+ multimodal_params : MultimodalParams ,
342+ group_req_objs : Optional [GroupReqObjs ] = None ,
343+ ):
344+ if self .nnodes > 1 and self .node_rank == 0 and self .args .dp == 1 :
345+ async with self .transfer_lock :
346+ for sender in self .multinode_req_manager :
347+ sender .send_pyobj (
348+ (prompt , sampling_params , multimodal_params ),
349+ protocol = pickle .HIGHEST_PROTOCOL ,
350+ )
351+ await self .transfer_to_next_module (group_req_objs )
352+ return
353+
354+ if self .nnodes > 1 and self .node_rank > 0 and self .args .dp == 1 :
355+ while True :
356+ if self .request_order_queue and self .request_order_queue [0 ] != group_req_objs .group_req_id :
357+ await asyncio .sleep (0.002 )
358+ continue
359+ else :
360+ async with self .transfer_lock :
361+ await self .transfer_to_next_module (group_req_objs )
362+ self .request_order_queue .pop (0 )
363+ break
364+ return
365+
366+ await self .transfer_to_next_module (group_req_objs )
367+ return
368+
342369 async def transfer_to_next_module (
343370 self ,
371+ group_req_objs : Optional [GroupReqObjs ] = None ,
344372 ):
345- group_req_objs : GroupReqObjs = await self . order_req_manager . get_next_request ()
373+
346374 if self .pd_mode == NodeRole .P :
347375 if self .enable_multimodal :
348376 self .send_to_visual .send_pyobj (
0 commit comments