 from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
-from .stats import Stats
 from .shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.utils.log_utils import init_logger, log_time_ready
 from lightllm.server.router.token_load import TokenLoad
@@ -45,6 +44,8 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
         # Compatibility for the multi-node pure-TP mode, where 1 // 2 == 0 and must be handled.
         self.dp_size_in_node = max(1, args.dp // self.nnodes)
         self.is_multinode_tp = args.nnodes > 1 and args.dp == 1
+        self.is_multinode_tp_master = self.is_multinode_tp and args.node_rank == 0
+        self.is_multinode_tp_slave = self.is_multinode_tp and args.node_rank != 0
         self.is_multinode_and_multidp = args.nnodes > 1 and args.dp > 1
         # Detect conservative scheduling: it never pauses requests, but can hurt throughput in some scenarios.
         self.is_safe_schedule = args.router_token_ratio == 0.0
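
The two flags added above split the multi-node pure-TP group into one scheduling master (node_rank 0) and follower nodes. A minimal sketch of the partition, with assumed values for `nnodes`, `dp`, and `node_rank`:

```python
# Sketch of the role split introduced above; nnodes/dp/node_rank are assumed values.
nnodes, dp = 2, 1
is_multinode_tp = nnodes > 1 and dp == 1  # multi-node pure-TP mode
for node_rank in range(nnodes):
    is_master = is_multinode_tp and node_rank == 0
    is_slave = is_multinode_tp and node_rank != 0
    print(node_rank, is_master, is_slave)  # rank 0 is the master; all other ranks are slaves
```
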
@@ -254,6 +255,8 @@ async def _step(self):
         """
         Event handling loop.
         """
+        # Receive new requests and attempt to schedule them.
+        await self._recv_new_reqs_and_schedule()
         # Check whether new requests should join inference:
         # under aggressive scheduling, any new inference batch is added immediately,
         # or the batch is added once the delayed-step count meets the current condition.
@@ -357,44 +360,96 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
         return

     def _generate_new_batch(self):
-        limit_router_queue_length = None
-        if self.is_multinode_tp:
-            # Use all_reduce to take the minimum value across nodes.
-            limit_router_queue_length = len(self.req_queue.waiting_req_list)
-            limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
-            dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-            limit_router_queue_length = limit_router_queue_length_tensor.item()
-
         # Scheduling must account for the currently running batch and for requests scheduled but not yet being inferred.
         new_batch = self.req_queue.generate_new_batch(
-            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch), limit_router_queue_length
+            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
         )
         self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
         return

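For context, the deleted `limit_router_queue_length` logic coordinated multi-node pure-TP scheduling by capping every node at the shortest waiting queue in the group; the id-level handshake in `_multinode_tp_generate_new_batch` below replaces it. A standalone sketch of the old rule, with hypothetical queue lengths:

```python
# Sketch of the removed coordination rule: all_reduce(MIN) over per-node queue
# lengths capped scheduling at the smallest waiting queue (lengths are hypothetical).
queue_lengths = [5, 3, 4]    # len(req_queue.waiting_req_list) on each node
limit = min(queue_lengths)   # what dist.all_reduce(..., op=MIN) computed
print(limit)                 # -> 3: no node schedules past its first 3 waiting reqs
```
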
-    async def loop_for_netio_req(self):
-        recv_max_count = 64
+    def _multinode_tp_generate_new_batch(self):
+        dist.barrier(group=self.mulitnode_group)
 
-        while True:
-            try:
-                # Pull at most recv_max_count requests from zmq per pass, so a backlog of queued requests cannot block the main loop.
-                for _ in range(recv_max_count):
-                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
-                    if isinstance(recv_req, GroupReqIndexes):
-                        self._add_req(recv_req)
-                    else:
-                        assert False, f"Error Req Inf {recv_req}"
-
-                # When many requests are queued, raise the per-pass receive limit.
-                recv_max_count = min(int(recv_max_count * 1.3), 256)
-
-            except zmq.ZMQError:
-                # Once the queue starts to drain, lower the per-pass receive limit.
-                recv_max_count = 64
+        # Scheduling must account for the currently running batch and for requests scheduled but not yet being inferred.
+        if self.is_multinode_tp_master:
+            new_batch = self.req_queue.generate_new_batch(
+                Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
+            )
+            if new_batch is not None:
+                req_ids = [req.request_id for req in new_batch.reqs]
+            else:
+                req_ids = []
+            dist.broadcast_object_list([len(req_ids)], src=0, group=self.mulitnode_group)
+            dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
+            req_id_select_mark = [1 for _ in range(len(req_ids))]
+            req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            back_req_list = []
+            for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+                if select == 0:
+                    req = new_batch.pop_req(req_id)
+                    back_req_list.append(req)
+            self.req_queue.waiting_req_list = back_req_list + self.req_queue.waiting_req_list
+            if new_batch is not None and new_batch.is_clear():
+                new_batch = None
+        else:
+            req_nums = [None]
+            dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+            req_num = req_nums[0]
+            req_ids = [None for _ in range(req_num)]
+            dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
+            all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+            req_id_select_mark = []
+            for req_id in req_ids:
+                req_id_select_mark.append(1 if req_id in all_req_id_set else 0)
+            req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            select_req_ids = []
+            for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+                if select == 1:
+                    select_req_ids.append(req_id)
+
+            select_reqs = []
+            for req_id in select_req_ids:
+                for req in self.req_queue.waiting_req_list:
+                    if req.request_id == req_id:
+                        select_reqs.append(req)
+
+            for req in select_reqs:
+                self.req_queue.waiting_req_list.remove(req)
+            if select_reqs:
+                new_batch = Batch(-1, reqs=select_reqs, dp_size_in_node=self.dp_size_in_node)
+            else:
+                new_batch = None
 
-            await asyncio.sleep(0.02)
+        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
+
+        dist.barrier(group=self.mulitnode_group)
+        return
+
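The handshake above is a two-round consensus: the master broadcasts its candidate request ids, every node contributes a 0/1 membership mask, and an element-wise `all_reduce(MIN)` keeps only ids that every node already holds; the rest go back to the master's waiting list, so the resulting batch is identical on all nodes. A single-process sketch of the selection rule (no process group needed; the per-node queues are made up):

```python
# Standalone sketch of the mask intersection computed by dist.all_reduce(MIN) above.
candidate_req_ids = [101, 102, 103, 104]   # broadcast by the node_rank == 0 master
node_waiting_sets = [                      # hypothetical waiting_req_list contents
    {101, 102, 103, 104},                  # master
    {101, 103, 104},                       # slave 1: req 102 has not arrived yet
    {101, 102, 103},                       # slave 2: req 104 has not arrived yet
]
masks = [[1 if rid in s else 0 for rid in candidate_req_ids] for s in node_waiting_sets]
reduced = [min(col) for col in zip(*masks)]  # element-wise MIN across nodes
selected = [rid for rid, keep in zip(candidate_req_ids, reduced) if keep == 1]
print(selected)  # -> [101, 103]; 102 and 104 are pushed back to the waiting list
```
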
+    async def _recv_new_reqs_and_schedule(self):
+        if not hasattr(self, "recv_max_count"):
+            self.recv_max_count = 64
+
+        try:
+            # Pull at most recv_max_count requests from zmq per pass, so a backlog of queued requests cannot block the main loop.
+            for _ in range(self.recv_max_count):
+                recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+                if isinstance(recv_req, GroupReqIndexes):
+                    self._add_req(recv_req)
+                else:
+                    assert False, f"Error Req Inf {recv_req}"
 
-            # Run a new scheduling pass only when the inference side has no paused requests.
+            # When many requests are queued, raise the per-pass receive limit.
+            self.recv_max_count = min(int(self.recv_max_count * 1.3), 256)
+
+        except zmq.ZMQError:
+            # Once the queue starts to drain, lower the per-pass receive limit.
+            self.recv_max_count = 64
+
+        if self.is_multinode_tp:
+            self._multinode_tp_generate_new_batch()
+        else:
             if self._get_paused_req_num() == 0:
                 self._generate_new_batch()
         return
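
`_recv_new_reqs_and_schedule` keeps the old adaptive drain: the per-pass budget grows by 1.3x (capped at 256) while the socket stays busy and snaps back to 64 once `zmq.ZMQError` signals an empty queue. A runnable sketch of the same pattern, with a plain deque standing in for the zmq socket (names here are illustrative, not lightllm's API):

```python
from collections import deque

def drain(queue: deque, handle, recv_max_count: int) -> int:
    """Drain up to recv_max_count items and return the next pass's budget."""
    try:
        for _ in range(recv_max_count):
            handle(queue.popleft())  # IndexError plays the role of zmq.ZMQError
        return min(int(recv_max_count * 1.3), 256)  # still busy: grow the budget
    except IndexError:
        return 64  # queue drained: reset to the small default

q = deque(range(100))
budget = 64
while q:
    budget = drain(q, print, budget)  # 64 items on pass 1, the remaining 36 on pass 2
```
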
@@ -436,6 +491,5 @@ def handle_exception(loop, context):
         raise

     pipe_writer.send("init ok")
-    loop.create_task(router.loop_for_fwd())
-    loop.run_until_complete(router.loop_for_netio_req())
+    loop.run_until_complete(router.loop_for_fwd())
     return
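
With request intake folded into `_step`, the entry point no longer spawns a separate netio task; only `loop_for_fwd` runs. A toy sketch of the consolidated single-loop shape (function bodies are placeholders, not lightllm's API):

```python
import asyncio

async def recv_new_reqs_and_schedule():
    await asyncio.sleep(0)  # placeholder for the non-blocking zmq drain plus scheduling

async def loop_for_fwd(steps: int = 3):
    for _ in range(steps):
        await recv_new_reqs_and_schedule()  # replaces the old loop_for_netio_req task
        print("forward step")

asyncio.run(loop_for_fwd())
```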