@@ -45,6 +45,8 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
         # Compatible with the multi-node pure-TP running mode, where 1 // 2 == 0 needs to be handled
         self.dp_size_in_node = max(1, args.dp // self.nnodes)
         self.is_multinode_tp = args.nnodes > 1 and args.dp == 1
+        self.is_multinode_tp_master = self.is_multinode_tp and args.node_rank == 0
+        self.is_multinode_tp_slave = self.is_multinode_tp and args.node_rank != 0
         self.is_multinode_and_multidp = args.nnodes > 1 and args.dp > 1
         # Whether conservative scheduling is in effect: it never pauses reqs, but can hurt throughput in some scenarios
         self.is_safe_schedule = args.router_token_ratio == 0.0
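
Side note on the two new role flags: a minimal sketch (not part of the patch; `SimpleNamespace` stands in for the real `StartArgs`) of how they partition nodes in a 2-node pure-TP launch:

```python
from types import SimpleNamespace

for node_rank in (0, 1):
    args = SimpleNamespace(nnodes=2, dp=1, node_rank=node_rank)
    is_multinode_tp = args.nnodes > 1 and args.dp == 1
    is_master = is_multinode_tp and args.node_rank == 0
    is_slave = is_multinode_tp and args.node_rank != 0
    print(node_rank, is_master, is_slave)

# node 0 -> master: runs generate_new_batch and broadcasts the candidate req ids
# node 1 -> slave: receives the ids and votes on which ones it also holds locally
```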
@@ -359,21 +361,73 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
         return
 
     def _generate_new_batch(self):
-        limit_router_queue_length = None
-        if self.is_multinode_tp:
-            # use all_reduce to take the minimum across nodes
-            limit_router_queue_length = len(self.req_queue.waiting_req_list)
-            limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
-            dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-            limit_router_queue_length = limit_router_queue_length_tensor.item()
-
         # Scheduling must account for the currently running batch and for requests that are scheduled but not yet running inference.
         new_batch = self.req_queue.generate_new_batch(
-            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch), limit_router_queue_length
+            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
         )
         self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
         return
 
+    def _multinode_tp_generate_new_batch(self):
+        dist.barrier(group=self.mulitnode_group)
+
+        # Scheduling must account for the currently running batch and for requests that are scheduled but not yet running inference.
+        if self.is_multinode_tp_master:
+            new_batch = self.req_queue.generate_new_batch(
+                Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
+            )
+            if new_batch is not None:
+                req_ids = [req.request_id for req in new_batch.reqs]
+            else:
+                req_ids = []
+            dist.broadcast_object_list([len(req_ids)], src=0, group=self.mulitnode_group)
+            dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
+            req_id_select_mark = [1 for _ in range(len(req_ids))]
+            req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            back_req_list = []
+            for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+                if select == 0:
+                    req = new_batch.pop_req(req_id)
+                    back_req_list.append(req)
+            self.req_queue.waiting_req_list = back_req_list + self.req_queue.waiting_req_list
+            if new_batch is not None and new_batch.is_clear():
+                new_batch = None
+        else:
+            req_nums = [None]
+            dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+            req_num = req_nums[0]
+            req_ids = [None for _ in range(req_num)]
+            dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
+            all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+            req_id_select_mark = []
+            for req_id in req_ids:
+                req_id_select_mark.append(1 if req_id in all_req_id_set else 0)
+            req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            select_req_ids = []
+            for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+                if select == 1:
+                    select_req_ids.append(req_id)
+
+            select_reqs = []
+            for req_id in select_req_ids:
+                for req in self.req_queue.waiting_req_list:
+                    if req.request_id == req_id:
+                        select_reqs.append(req)
+
+            for req in select_reqs:
+                self.req_queue.waiting_req_list.remove(req)
+            if select_reqs:
+                new_batch = Batch(-1, reqs=select_reqs, dp_size_in_node=self.dp_size_in_node)
+            else:
+                new_batch = None
+
+        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
+
+        dist.barrier(group=self.mulitnode_group)
+        return
+
     async def _recv_new_reqs_and_schedule(self):
         if not hasattr(self, "recv_max_count"):
             self.recv_max_count = 64
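
The selection logic in the method added above is easier to see stripped of the distributed plumbing: the master's broadcast plus the all_reduce with ReduceOp.MIN over 0/1 votes is a set intersection, keeping only reqs that every node has already received. A pure-Python sketch (the helper name and inputs are illustrative, not from the patch):

```python
def min_vote_select(candidate_ids, per_node_id_sets):
    # Each node votes 1 for an id it holds locally, 0 otherwise; reducing the
    # votes with MIN keeps an id only when every node voted 1, which is exactly
    # the intersection of the nodes' waiting-queue id sets.
    votes = [min(1 if rid in ids else 0 for ids in per_node_id_sets) for rid in candidate_ids]
    return [rid for rid, v in zip(candidate_ids, votes) if v == 1]

# Master proposes reqs 7, 8, 9; the second node has not yet received req 8,
# so only 7 and 9 are scheduled and 8 goes back to the master's waiting list.
assert min_vote_select([7, 8, 9], [{7, 8, 9}, {7, 9}]) == [7, 9]
```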
@@ -394,9 +448,11 @@ async def _recv_new_reqs_and_schedule(self):
             # When the queue has started to drain, lower the number of reqs accepted per round
             self.recv_max_count = 64
 
-        # Only run new scheduling when no requests are paused on the inference side
-        if self._get_paused_req_num() == 0:
-            self._generate_new_batch()
+        if self.is_multinode_tp:
+            self._multinode_tp_generate_new_batch()
+        else:
+            if self._get_paused_req_num() == 0:
+                self._generate_new_batch()
         return
 
     def clean_up(self):
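
For reference, a self-contained demo of the two torch.distributed primitives the new path leans on: broadcast_object_list to ship the master's candidate req ids, and all_reduce(MIN) over a 0/1 vote tensor. Assumptions not in the patch: two local processes stand in for the two nodes, and the TCP port is arbitrary.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29512",
        rank=rank, world_size=world_size,
    )
    # Requests each "node" happens to hold; req "b" has not reached rank 1 yet.
    local_ids = [["a", "b", "c"], ["a", "c"]][rank]

    if rank == 0:  # master proposes its candidate ids
        candidates = list(local_ids)
        dist.broadcast_object_list([len(candidates)], src=0)
        dist.broadcast_object_list(candidates, src=0)
    else:  # slaves first learn the count, then receive the ids in place
        count = [None]
        dist.broadcast_object_list(count, src=0)
        candidates = [None] * count[0]
        dist.broadcast_object_list(candidates, src=0)

    # 0/1 vote per candidate; MIN keeps only unanimously held ids.
    votes = torch.tensor([1 if c in local_ids else 0 for c in candidates], dtype=torch.int32)
    dist.all_reduce(votes, op=dist.ReduceOp.MIN)
    kept = [c for c, v in zip(candidates, votes.tolist()) if v == 1]
    print(f"rank {rank} schedules {kept}")  # both ranks print ['a', 'c']
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

Because every rank ends up with the identical `kept` list, each node can build the same batch locally, which is what lets the patch construct `Batch(-1, reqs=select_reqs, ...)` on slaves without shipping the batch itself.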