1- import copy
21import time
3- import uuid
42import uvloop
53import asyncio
64import torch
7- import rpyc
85import pickle
9- import threading
106import inspect
117
128asyncio .set_event_loop_policy (uvloop .EventLoopPolicy ())
13- import concurrent .futures
149import zmq
1510import zmq .asyncio
1611import torch .multiprocessing as mp
1712import torch .distributed as dist
1813import multiprocessing
1914from typing import Dict , List , Optional
20- from .batch import Batch
15+ from .batch import Batch , Req
2116from .model_infer .model_rpc import start_model_process , ModelRpcClient
2217from .req_queue import build_req_queue
23- from lightllm .utils .infer_utils import calculate_time
2418from lightllm .server .core .objs .io_objs import GroupReqIndexes
2519from lightllm .server .core .objs import ShmReqManager , StartArgs
2620from .dynamic_prompt .radix_cache import RadixCacheReadOnlyClient
@@ -108,8 +102,8 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
108102 g_router_lock .obj = self .router_lock
109103
110104 # Thread pool used to overlap scheduling with inference
111- # self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
112- # self.schedule_task = None
105+ self .schedule_new_batch : Batch = None
106+ self .schedule_event = asyncio . Event ()
113107 return
114108
115109 async def wait_to_model_ready (self ):
@@ -140,8 +134,6 @@ async def wait_to_model_ready(self):
140134 self .model_rpc_servers .append (rpc_model )
141135
142136 self .model_rpc_client = ModelRpcClient (
143- model_infer_servers = self .model_rpc_servers ,
144- world_size = self .world_size ,
145137 rpc_event = self .rpc_event ,
146138 rpc_finished_event = self .rpc_finished_event ,
147139 )
@@ -223,7 +215,6 @@ def add_req(self, group_req_indexes: GroupReqIndexes):
223215 logger .info (f"router recive req id { req .request_id } cost time { time .time () - req .start_time } s" )
224216 self .req_queue .extend (req_group )
225217 self .send_to_detokenization .send_pyobj (group_req_indexes , protocol = pickle .HIGHEST_PROTOCOL )
226-
227218 return
228219
229220 async def loop_for_fwd (
@@ -285,7 +276,7 @@ async def loop_for_fwd(
285276 if self .running_batch is None :
286277 await asyncio .sleep (0.01 ) # 10ms
287278
288- def get_new_batch (self ):
279+ def generate_new_batch (self ):
289280 limit_router_queue_length = None
290281 if self .is_multinode_tp :
291282 # Use all_reduce to obtain the minimum value
@@ -294,25 +285,30 @@ def get_new_batch(self):
294285 dist .all_reduce (limit_router_queue_length_tensor , op = dist .ReduceOp .MIN , group = self .mulitnode_group )
295286 limit_router_queue_length = limit_router_queue_length_tensor .item ()
296287
297- new_batch = self .req_queue .generate_new_batch (self .running_batch , limit_router_queue_length )
298- return new_batch
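288+ # Scheduling must take into account the currently running batch as well as the requests that have been scheduled but have not yet been run for inference.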
289+ new_batch = self .req_queue .generate_new_batch (
290+ Batch .merge_two_batch (self .running_batch , self .schedule_new_batch ), limit_router_queue_length
291+ )
292+ self .schedule_new_batch = Batch .merge_two_batch (self .schedule_new_batch , new_batch )
293+ return
299294
300295 async def _step (self ):
301296 """
302297 事件处理循环
303298 """
304- # Remove all reqs that are already finished
305- # When there is currently no running request
306- new_batch = None
307- if not self .batch_queue .empty ():
308- new_batch = self .batch_queue .get_nowait ()
309- if new_batch is not None :
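299+ # Check whether there are new requests to add to inference.
300+ # When aggressive scheduling is in effect, a new inference batch must be added as soon as one exists.
301+ # Otherwise, once the number of delayed steps satisfies the wait condition, the new inference batch must be added as well.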
302+ if (self .schedule_new_batch is not None ) and (
303+ (not self .args .disable_aggressive_schedule ) or (self .has_wait_tokens >= self .max_wait_tokens )
304+ ):
305+ new_batch = self .schedule_new_batch
306+ self .schedule_new_batch = None
307+ self ._add_new_batch_to_running_batch (new_batch = new_batch )
310308 await self ._prefill_batch (new_batch )
311- self ._filter_runing_batch ()
312- if self .running_batch is None :
313- self .running_batch = new_batch
314- else :
315- self .running_batch .merge (new_batch )
309+ self .stats_tool .count_prompt_tokens (new_batch )
310+ self ._filter_reqs_from_running_batch ()
311+ self .has_wait_tokens = 0
316312
317313 # Check if need pause some requests for decode.
318314 for dp_index in range (self .dp_size_in_node ):
@@ -327,41 +323,43 @@ async def _step(self):
327323
328324 # Decode
329325 self .stats_tool .count_output_tokens (self .running_batch )
330- await self ._decode_batch (self .running_batch )
331- if self .world_size // self .nnodes == 1 :
332- # When node_world_size == 1 the coroutine does not yield, so new requests cannot be scheduled; hence sleep 1ms (this could be revised)
333- await asyncio .sleep (0.001 )
334- self ._filter_runing_batch ()
326+ await self ._decode_batch ()
327+ self ._filter_reqs_from_running_batch ()
335328 self .has_wait_tokens += 1
336329 return
337330
# Run the prefill stage for `batch`: every request in `batch.reqs` is converted
# with `to_router_rpc_obj()` and the resulting list is awaited through
# `self.model_rpc_client.prefill`. Afterwards `_send_detokenization_pack()` is
# called and a summary of the batch is logged at debug level.
# NOTE(review): the removed line 341 shows that finished requests are no longer
# filtered out of `batch` inside this method after this change.
338331 async def _prefill_batch (self , batch : Batch ):
332+ # Add the new requests
339333 reqs = [r .to_router_rpc_obj () for r in batch .reqs ]
340334 await self .model_rpc_client .prefill (reqs )
341- batch .filter_out_finished_req (self .shm_req_manager )
342335 self ._send_detokenization_pack ()
343-
344336 logger .debug (f"Prefill Batch: { batch .simple_log ()} \n " )
345337 return
346338
# Run one decode step: sets `self.schedule_event`, awaits
# `self.model_rpc_client.decode()`, then calls `_send_detokenization_pack()`.
# The new version takes no `batch` argument and no longer filters finished
# requests here (see the removed lines 349-351).
# NOTE(review): `loop_for_netio_req` waits on `schedule_event` and calls
# `generate_new_batch()` when it is set; presumably the event is set before the
# await so that scheduling can overlap with the decode RPC — confirm.
347- async def _decode_batch (self , batch : Batch ):
339+ async def _decode_batch (self ):
340+ self .schedule_event .set ()
348341 await self .model_rpc_client .decode ()
349- # When self.is_multinode_and_multidp is True, the batch object passed in may be None.
350- if batch is not None :
351- batch .filter_out_finished_req (self .shm_req_manager )
352-
353342 self ._send_detokenization_pack ()
354343 return
355344
# Ask the model RPC client to pause the given requests: collects the
# `request_id` of each entry in `pasue_reqs` and awaits
# `self.model_rpc_client.pause_reqs` with that id list.
# The change only adds the `List[Req]` annotation to the parameter.
356- async def _pause_reqs (self , pasue_reqs ):
345+ async def _pause_reqs (self , pasue_reqs : List [ Req ] ):
357346 pasue_req_ids = [r .request_id for r in pasue_reqs ]
358347 await self .model_rpc_client .pause_reqs (pasue_req_ids )
359348 return
360349
# Removed helper `_filter_runing_batch`: it only reset `self.running_batch` to
# None when the running batch reported `is_clear()`.
361- def _filter_runing_batch (self ):
362- if self .running_batch is not None and self .running_batch .is_clear ():
363- self .running_batch = None
364- return
# Added helper `_add_new_batch_to_running_batch`: installs `new_batch` as
# `self.running_batch` when nothing is running, otherwise merges it into the
# existing running batch via `merge`.
350+ def _add_new_batch_to_running_batch (self , new_batch : Batch ):
351+ if self .running_batch is None :
352+ self .running_batch = new_batch
353+ else :
354+ self .running_batch .merge (new_batch )
355+ return
356+
# Added helper `_filter_reqs_from_running_batch`: when a running batch exists,
# first calls `filter_out_finished_req(self.shm_req_manager)` on it and then
# resets `self.running_batch` to None if the batch reports `is_clear()`.
357+ def _filter_reqs_from_running_batch (self ):
358+ if self .running_batch is not None :
359+ self .running_batch .filter_out_finished_req (self .shm_req_manager )
360+ if self .running_batch .is_clear ():
361+ self .running_batch = None
362+ return
365363
365363
366364 def _can_decode (self , batch : Batch , dp_index : int ):
367365 if self .is_pd_run_mode or self .is_safe_schedule or batch is None :
@@ -389,19 +387,35 @@ def get_used_tokens(self, dp_index):
389387 return self .max_total_token_num - self .read_only_statics_mem_manager .get_unrefed_token_num (dp_index )
390388
# Endless loop that pulls request groups from `self.recv_from_httpserver`
# without blocking and triggers scheduling of new batches.
# New behaviour visible in the added lines:
#   * up to `recv_max_count` messages are received per pass (starting at 66);
#   * after a full pass without `zmq.ZMQError` the limit grows by a factor of
#     1.3, capped at 300; on `zmq.ZMQError` it is reset to 66;
#   * the loop waits up to 0.02 s for `self.schedule_event`; if the event is
#     set, `generate_new_batch()` is called and the event is cleared.
# NOTE(review): indentation is not preserved in this view, so whether the wait
# on `schedule_event` sits inside the `except` branch or after the try/except
# cannot be determined from here — confirm against the real file.
# NOTE(review): `zmq.ZMQError` is presumably what `recv_pyobj(zmq.NOBLOCK)`
# raises when no message is pending — confirm against the pyzmq documentation.
391389 async def loop_for_netio_req (self ):
390+ recv_max_count = 66
391+
392392 while True :
393393 try :
394- recv_req : GroupReqIndexes = self .recv_from_httpserver .recv_pyobj (zmq .NOBLOCK )
395- if isinstance (recv_req , GroupReqIndexes ):
396- self .add_req (recv_req )
397- else :
398- assert False , f"Error Req Inf { recv_req } "
394+ # Take at most recv_max_count requests from zmq at a time, so that a large number of queued requests in zmq does not block the main loop.
395+ for _ in range (recv_max_count ):
396+ recv_req : GroupReqIndexes = self .recv_from_httpserver .recv_pyobj (zmq .NOBLOCK )
397+ if isinstance (recv_req , GroupReqIndexes ):
398+ self .add_req (recv_req )
399+ else :
400+ assert False , f"Error Req Inf { recv_req } "
401+
402+ # When many requests are present in the queue, raise the number accepted per pass
403+ recv_max_count = min (int (recv_max_count * 1.3 ), 300 )
404+
399405 except zmq .ZMQError :
400- new_batch = self .get_new_batch ()
401- if new_batch is not None :
402- self .batch_queue .put_nowait (new_batch )
403- await asyncio .sleep (0.005 )
404- continue
406+ # When the queue has started to drain, lower the number accepted per pass
407+ recv_max_count = 66
408+
409+ try :
410+ await asyncio .wait_for (self .schedule_event .wait (), timeout = 0.02 )
411+ except asyncio .TimeoutError :
412+ pass
413+
414+ if self .schedule_event .is_set ():
415+ self .generate_new_batch ()
416+ self .schedule_event .clear ()
417+
# The visible loop is `while True` with no `break`, so this trailing `return`
# is only reached if the loop is left some other way.
418+ return
405419
406420 def clean_up (self ):
407421 return
@@ -440,7 +454,6 @@ def handle_exception(loop, context):
440454 loop = asyncio .new_event_loop ()
441455 loop .set_exception_handler (handle_exception )
442456 asyncio .set_event_loop (loop )
443- router .batch_queue = asyncio .Queue ()
444457
445458 loop .create_task (router .loop_for_fwd ())
446459 loop .run_until_complete (router .loop_for_netio_req ())
0 commit comments