@@ -108,8 +108,9 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_port
         g_router_lock.obj = self.router_lock

         # Thread pool used to overlap scheduling with inference
-        # self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-        # self.schedule_task = None
+        self.schedule_new_batch: Batch = None
+        self.schedule_lock = asyncio.Lock()
+        self.schedule_sem = asyncio.Semaphore(1)
         return

     async def wait_to_model_ready(self):
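The commented-out thread pool is replaced with explicit state for overlapping scheduling with inference: a single-slot pre-scheduled batch plus asyncio primitives. Below is a minimal sketch of that single-slot handoff; `FakeBatch` and `OverlapState` are hypothetical stand-ins for the real `Batch` and router, and only the `schedule_new_batch` / `schedule_lock` / `schedule_sem` names mirror the diff.

```python
import asyncio
from typing import List, Optional

class FakeBatch:
    # Hypothetical stand-in for lightllm's Batch.
    def __init__(self, reqs: Optional[List] = None):
        self.reqs = list(reqs or [])

class OverlapState:
    def __init__(self):
        self.schedule_new_batch: Optional[FakeBatch] = None  # scheduled, not yet prefilled
        self.schedule_lock = asyncio.Lock()       # guards the slot across await points
        self.schedule_sem = asyncio.Semaphore(1)  # at most one scheduling pass in flight

    async def produce(self, reqs: List) -> None:
        # Called from the netio loop: accumulate newly scheduled requests.
        async with self.schedule_lock:
            if self.schedule_new_batch is None:
                self.schedule_new_batch = FakeBatch(reqs)
            else:
                self.schedule_new_batch.reqs.extend(reqs)

    async def consume(self) -> Optional[FakeBatch]:
        # Called from the forward loop: take whatever has been scheduled so far.
        async with self.schedule_lock:
            batch, self.schedule_new_batch = self.schedule_new_batch, None
            return batch
```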
@@ -285,34 +286,37 @@ async def loop_for_fwd(
             if self.running_batch is None:
                 await asyncio.sleep(0.01)  # 10ms

-    def get_new_batch(self):
+    def generate_new_batch(self):
         limit_router_queue_length = None
         if self.is_multinode_tp:
             # Use all_reduce to get the minimum value
             limit_router_queue_length = len(self.req_queue.waiting_req_list)
             limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
             dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
             limit_router_queue_length = limit_router_queue_length_tensor.item()
-
-        new_batch = self.req_queue.generate_new_batch(self.running_batch, limit_router_queue_length)
-        return new_batch
+
+        # Scheduling must account for the currently running batch and for requests that were scheduled but have not started inference yet.
+        new_batch = self.req_queue.generate_new_batch(Batch.merge(self.running_batch, self.schedule_new_batch), limit_router_queue_length)
+        self.schedule_new_batch = Batch.merge(self.schedule_new_batch, new_batch)
+        return

     async def _step(self):
         """
         Event handling loop
         """
         # Remove all reqs that have already finished
         # When there are no currently running requests
-        new_batch = None
-        if not self.batch_queue.empty():
-            new_batch = self.batch_queue.get_nowait()
+        new_batch = self.schedule_new_batch
+        self.schedule_new_batch = None
         if new_batch is not None:
             await self._prefill_batch(new_batch)
+            self.stats_tool.count_prompt_tokens(new_batch)
             self._filter_runing_batch()
-            if self.running_batch is None:
-                self.running_batch = new_batch
-            else:
-                self.running_batch.merge(new_batch)
+            if not new_batch.is_clear():
+                if self.running_batch is None:
+                    self.running_batch = new_batch
+                else:
+                    self.running_batch.merge(new_batch)

         # Check if need pause some requests for decode.
         for dp_index in range(self.dp_size_in_node):
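With this change, `generate_new_batch` schedules against the union of the running batch and the already-scheduled-but-not-yet-prefilled batch, and `_step` merges the prefilled batch into `running_batch` only when it is not empty. The following is a sketch of the None-tolerant combining behavior that `Batch.merge(a, b)` and `is_clear()` appear to rely on here; it is an assumption for illustration, not the real `Batch` class, which is not part of this diff.

```python
from typing import Optional

class Batch:
    # Illustrative only: just enough to show the assumed merge/is_clear semantics.
    def __init__(self, reqs=None):
        self.reqs = list(reqs or [])

    @staticmethod
    def merge(a: Optional["Batch"], b: Optional["Batch"]) -> Optional["Batch"]:
        # Either side may be None; returning the other side unchanged lets
        # generate_new_batch treat "nothing scheduled yet" uniformly.
        if a is None:
            return b
        if b is None:
            return a
        merged = Batch(a.reqs)
        merged.reqs.extend(b.reqs)
        return merged

    def is_clear(self) -> bool:
        # After prefill, finished or aborted reqs may have been filtered out;
        # an empty batch must not be merged into running_batch.
        return len(self.reqs) == 0
```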
@@ -391,17 +395,20 @@ def get_used_tokens(self, dp_index):
     async def loop_for_netio_req(self):
         while True:
             try:
-                recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
-                if isinstance(recv_req, GroupReqIndexes):
-                    self.add_req(recv_req)
-                else:
-                    assert False, f"Error Req Inf {recv_req}"
+                # Pull at most 36 requests from zmq per pass, so a backlog in the zmq queue cannot block the main loop.
+                for _ in range(36):
+                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+                    if isinstance(recv_req, GroupReqIndexes):
+                        self.add_req(recv_req)
+                    else:
+                        assert False, f"Error Req Inf {recv_req}"
             except zmq.ZMQError:
-                new_batch = self.get_new_batch()
-                if new_batch is not None:
-                    self.batch_queue.put_nowait(new_batch)
-                await asyncio.sleep(0.005)
-                continue
+                pass
+
+            # Schedule a new batch
+
+            self.generate_new_batch()
+            await asyncio.sleep(0.005)

     def clean_up(self):
         return
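`loop_for_netio_req` now drains at most 36 requests per pass with a non-blocking recv, treats `zmq.ZMQError` as "queue empty", and then runs scheduling on every iteration instead of pushing batches through `batch_queue`. A small sketch of that bounded-drain pattern, with hypothetical socket and handler names standing in for the router's own attributes:

```python
import asyncio
import zmq

MAX_RECV_PER_PASS = 36  # mirrors the cap in the diff

async def netio_pass(socket: zmq.Socket, handle_req, schedule):
    try:
        for _ in range(MAX_RECV_PER_PASS):
            # raises zmq.Again (a ZMQError subclass) as soon as the queue is empty
            req = socket.recv_pyobj(zmq.NOBLOCK)
            handle_req(req)
    except zmq.ZMQError:
        pass  # queue drained, or fewer than 36 requests were waiting
    schedule()                  # always attempt to schedule, even with no new reqs
    await asyncio.sleep(0.005)  # yield so loop_for_fwd can make progress
```

Capping the number of `recv_pyobj` calls per pass bounds the work done before control returns to the event loop, so a burst of incoming requests cannot starve the forward loop.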
@@ -440,8 +447,8 @@ def handle_exception(loop, context):
     loop = asyncio.new_event_loop()
     loop.set_exception_handler(handle_exception)
     asyncio.set_event_loop(loop)
-    router.batch_queue = asyncio.Queue()

     loop.create_task(router.loop_for_fwd())
     loop.run_until_complete(router.loop_for_netio_req())
     return
+