@@ -1,26 +1,20 @@
-import copy
 import time
-import uuid
 import uvloop
 import asyncio
 import torch
-import rpyc
 import pickle
-import threading
 import inspect
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-import concurrent.futures
 import zmq
 import zmq.asyncio
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import multiprocessing
 from typing import Dict, List, Optional
-from .batch import Batch
+from .batch import Batch, Req
 from .model_infer.model_rpc import start_model_process, ModelRpcClient
 from .req_queue import build_req_queue
-from lightllm.utils.infer_utils import calculate_time
 from lightllm.server.core.objs.io_objs import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
@@ -109,8 +103,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_port):
 
         # Thread pool used to fold (overlap) scheduling and inference
         self.schedule_new_batch: Batch = None
-        self.schedule_lock = asyncio.Lock()
-        self.schedule_sem = asyncio.Semaphore(1)
+        self.schedule_event = asyncio.Event()
         return
 
     async def wait_to_model_ready(self):
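Note: the lock/semaphore pair is replaced by a single asyncio.Event, so the decode step can wake the scheduling loop instead of the loop polling on a fixed sleep. A minimal standalone sketch of this wake-up pattern, with toy names rather than the actual RouterManager:

```python
import asyncio

class ToyRouter:
    """Toy model of the Event-based wake-up this diff introduces."""

    def __init__(self):
        self.schedule_event = asyncio.Event()

    async def decode_step(self):
        # Decode signals that scheduling work may now be useful.
        self.schedule_event.set()
        await asyncio.sleep(0)  # stand-in for model_rpc_client.decode()

    async def schedule_loop(self):
        for _ in range(3):
            try:
                # Wake as soon as decode signals, or after 20 ms at the latest.
                await asyncio.wait_for(self.schedule_event.wait(), timeout=0.02)
            except asyncio.TimeoutError:
                pass
            if self.schedule_event.is_set():
                self.schedule_event.clear()
                print("generate_new_batch() would run here")

async def main():
    r = ToyRouter()
    await asyncio.gather(r.schedule_loop(), r.decode_step())

asyncio.run(main())
```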
@@ -222,7 +215,6 @@ def add_req(self, group_req_indexes: GroupReqIndexes):
             logger.info(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s")
         self.req_queue.extend(req_group)
         self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
-
         return
 
     async def loop_for_fwd(
@@ -312,14 +304,10 @@ async def _step(self):
         ):
             new_batch = self.schedule_new_batch
             self.schedule_new_batch = None
+            self._add_new_batch_to_running_batch(new_batch=new_batch)
             await self._prefill_batch(new_batch)
             self.stats_tool.count_prompt_tokens(new_batch)
-            self._filter_runing_batch()
-            if not new_batch.is_clear():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge(new_batch)
+            self._filter_reqs_from_running_batch()
             self.has_wait_tokens = 0
 
         # Check if we need to pause some requests for decode.
@@ -335,38 +323,43 @@ async def _step(self):
 
         # Decode
         self.stats_tool.count_output_tokens(self.running_batch)
-        await self._decode_batch(self.running_batch)
-        self._filter_runing_batch()
+        await self._decode_batch()
+        self._filter_reqs_from_running_batch()
         self.has_wait_tokens += 1
         return
 
     async def _prefill_batch(self, batch: Batch):
+        # Add the new requests
         reqs = [r.to_router_rpc_obj() for r in batch.reqs]
         await self.model_rpc_client.prefill(reqs)
-        batch.filter_out_finished_req(self.shm_req_manager)
         self._send_detokenization_pack()
-
         logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
         return
 
-    async def _decode_batch(self, batch: Batch):
+    async def _decode_batch(self):
+        self.schedule_event.set()
         await self.model_rpc_client.decode()
-        # When self.is_multinode_and_multidp is True, the batch passed in may be None.
-        if batch is not None:
-            batch.filter_out_finished_req(self.shm_req_manager)
-
         self._send_detokenization_pack()
         return
 
-    async def _pause_reqs(self, pasue_reqs):
+    async def _pause_reqs(self, pasue_reqs: List[Req]):
         pasue_req_ids = [r.request_id for r in pasue_reqs]
         await self.model_rpc_client.pause_reqs(pasue_req_ids)
         return
 
-    def _filter_runing_batch(self):
-        if self.running_batch is not None and self.running_batch.is_clear():
-            self.running_batch = None
-        return
+    def _add_new_batch_to_running_batch(self, new_batch: Batch):
+        if self.running_batch is None:
+            self.running_batch = new_batch
+        else:
+            self.running_batch.merge(new_batch)
+        return
+
+    def _filter_reqs_from_running_batch(self):
+        if self.running_batch is not None:
+            self.running_batch.filter_out_finished_req(self.shm_req_manager)
+            if self.running_batch.is_clear():
+                self.running_batch = None
+        return
 
     def _can_decode(self, batch: Batch, dp_index: int):
         if self.is_pd_run_mode or self.is_safe_schedule or batch is None:
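The two new helpers centralize the running-batch lifecycle: a freshly scheduled batch is merged into running_batch before prefill, and finished requests are filtered out of the merged batch afterwards, dropping the batch entirely once it empties. A toy sketch of that merge-then-filter flow, with a hypothetical batch class standing in for lightllm's Batch:

```python
from typing import List, Optional

class ToyBatch:
    """Hypothetical stand-in for lightllm's Batch."""

    def __init__(self, reqs: List[str]):
        self.reqs = reqs

    def merge(self, other: "ToyBatch") -> None:
        self.reqs.extend(other.reqs)

    def filter_out_finished(self, finished: set) -> None:
        self.reqs = [r for r in self.reqs if r not in finished]

    def is_clear(self) -> bool:
        return not self.reqs

running_batch: Optional[ToyBatch] = None

def add_new_batch_to_running_batch(new_batch: ToyBatch) -> None:
    # Merge the new batch into the running batch (or adopt it outright).
    global running_batch
    if running_batch is None:
        running_batch = new_batch
    else:
        running_batch.merge(new_batch)

def filter_reqs_from_running_batch(finished: set) -> None:
    # Drop finished requests; discard the batch once it is empty.
    global running_batch
    if running_batch is not None:
        running_batch.filter_out_finished(finished)
        if running_batch.is_clear():
            running_batch = None

add_new_batch_to_running_batch(ToyBatch(["a", "b"]))
filter_reqs_from_running_batch({"a", "b"})  # batch empties
assert running_batch is None
```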
@@ -394,21 +387,35 @@ def get_used_tokens(self, dp_index):
         return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
 
     async def loop_for_netio_req(self):
+        recv_max_count = 66
+
         while True:
             try:
-                # Fetch at most 24 requests from zmq at a time, to keep too many queued requests from blocking the main loop.
-                for _ in range(36):
+                # Fetch at most recv_max_count requests from zmq at a time, to keep too many queued requests from blocking the main loop.
+                for _ in range(recv_max_count):
                     recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
                     if isinstance(recv_req, GroupReqIndexes):
                         self.add_req(recv_req)
                     else:
                         assert False, f"Error Req Inf {recv_req}"
+
+                # When there are many requests in the queue, raise the per-iteration receive count
+                recv_max_count = min(int(recv_max_count * 1.3), 300)
+
             except zmq.ZMQError:
+                # When the queue has started to drain, lower the per-iteration receive count back down
+                recv_max_count = 66
+
+            try:
+                await asyncio.wait_for(self.schedule_event.wait(), timeout=0.02)
+            except asyncio.TimeoutError:
                 pass
 
-            # Schedule a new batch
-            self.generate_new_batch()
-            await asyncio.sleep(0.005)
+            if self.schedule_event.is_set():
+                self.generate_new_batch()
+                self.schedule_event.clear()
+
+        return
 
     def clean_up(self):
         return
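The receive loop now adapts its per-iteration budget: while requests keep arriving it grows the budget by 1.3x up to a cap of 300, and the moment the zmq queue drains (ZMQError on a non-blocking recv) it falls back to the floor of 66. The same policy replayed on a plain queue.Queue, as a standalone sketch under those assumptions:

```python
import queue

def drain_adaptively(q: "queue.Queue[int]", rounds: int = 3) -> None:
    recv_max_count = 66  # floor, as in the diff
    for _ in range(rounds):
        try:
            for _ in range(recv_max_count):
                q.get_nowait()  # plays the role of recv_pyobj(zmq.NOBLOCK)
            # Queue still busy: grow the budget multiplicatively, capped at 300.
            recv_max_count = min(int(recv_max_count * 1.3), 300)
        except queue.Empty:  # plays the role of zmq.ZMQError
            # Queue drained: reset the budget to the floor.
            recv_max_count = 66
        print("budget:", recv_max_count)

q: "queue.Queue[int]" = queue.Queue()
for i in range(200):
    q.put(i)
drain_adaptively(q)  # prints 85, then 110, then 66 once the queue runs dry
```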