@@ -200,7 +200,7 @@ async def wait_to_model_ready(self):
200200
201201 return
202202
203- async def add_req (self , group_req_indexes : GroupReqIndexes ):
203+ def add_req (self , group_req_indexes : GroupReqIndexes ):
204204 req_group = []
205205 for req_index in group_req_indexes .shm_req_indexes :
206206 req = self .shm_req_manager .get_req_obj_by_index (req_index )
@@ -211,6 +211,7 @@ async def add_req(self, group_req_indexes: GroupReqIndexes):
211211 logger .info (f"router recive req id { req .request_id } cost time { time .time () - req .start_time } s" )
212212 self .req_queue .extend (req_group )
213213 self .send_to_detokenization .send_pyobj (group_req_indexes , protocol = pickle .HIGHEST_PROTOCOL )
214+
214215 return
215216
216217 async def loop_for_fwd (
@@ -262,18 +263,18 @@ async def get_schedule_result(self, running_batch: Batch):
262263 if self .schedule_task is None :
263264
264265 def get_new_batch ():
265- current_waiting_num = None
266+ limit_router_queue_length = None
266267 if self .nnodes > 1 and self .args .dp == 1 :
267268 # 使用 all_reduce 获取最小值
268- current_waiting_num = len (self .req_queue .waiting_req_list )
269- current_waiting_num_tensor = torch .tensor (current_waiting_num , dtype = torch .int32 , device = "cpu" )
270- dist .all_reduce (current_waiting_num_tensor , op = dist .ReduceOp .MIN , group = self .mulitnode_group )
271- current_waiting_num = current_waiting_num_tensor .item ()
269+ limit_router_queue_length = len (self .req_queue .waiting_req_list )
270+ limit_router_queue_length_tensor = torch .tensor (limit_router_queue_length , dtype = torch .int32 , device = "cpu" )
271+ dist .all_reduce (limit_router_queue_length_tensor , op = dist .ReduceOp .MIN , group = self .mulitnode_group )
272+ limit_router_queue_length = limit_router_queue_length_tensor .item ()
272273
273274 self .overlap_event .wait (timeout = 0.020 )
274275 self .overlap_event .clear ()
275276 time .sleep (0.003 )
276- new_batch = self .req_queue .generate_new_batch (running_batch , current_waiting_num )
277+ new_batch = self .req_queue .generate_new_batch (running_batch , limit_router_queue_length )
277278 return new_batch
278279
279280 self .schedule_task = asyncio .get_running_loop ().run_in_executor (self .overlap_thread_pool , get_new_batch )
@@ -399,7 +400,7 @@ async def loop_for_netio_req(self):
399400 while True :
400401 recv_req : GroupReqIndexes = await self .recv_from_httpserver .recv_pyobj ()
401402 if isinstance (recv_req , GroupReqIndexes ):
402- await self .add_req (recv_req )
403+ self .add_req (recv_req )
403404 else :
404405 assert False , f"Error Req Inf { recv_req } "
405406
@@ -408,7 +409,6 @@ def clean_up(self):
408409
409410
410411def start_router_process (args , router_port , detokenization_port , model_rpc_ports , metric_port , pipe_writer ):
411-
412412 # 注册 graceful 退出的处理
413413 graceful_registry (inspect .currentframe ().f_code .co_name )
414414 start_parent_check_thread ()