@@ -368,63 +368,73 @@ def _generate_new_batch(self):
368368 return
369369
def _multinode_tp_generate_new_batch(self):
    """Schedule a new batch consistently across all ranks of a multi-node TP group.

    Protocol (every collective below must run in exactly this order on every
    rank, otherwise the group deadlocks):
      1. barrier — align all ranks at the scheduling point.
      2. The master rank proposes a batch (considering the running batch and
         already-scheduled-but-not-yet-run requests) and broadcasts first the
         request count, then the request ids.
      3. All ranks vote via all_reduce(MIN) on a 0/1 mark per request id:
         a request survives only if *every* rank holds it. The master returns
         rejected requests to the front of its waiting queue; workers build
         their local batch from the surviving ids.
      4. The surviving batch is merged into ``self.schedule_new_batch``.
      5. barrier — align before returning.

    When the master proposes zero requests, the id broadcast and the vote are
    skipped on all ranks (count 0 is broadcast, both sides take the early-out
    branch), keeping the collective sequences identical.

    NOTE: ``mulitnode_group`` (sic) is the attribute name defined elsewhere in
    this class; the typo is kept deliberately.
    """
    try:
        dist.barrier(group=self.mulitnode_group)

        if self.is_multinode_tp_master:
            # Master proposes: scheduling must account for the currently
            # running batch and requests scheduled but not yet inferred.
            new_batch = self.req_queue.generate_new_batch(
                Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
            )
            if new_batch is not None:
                req_ids = [req.request_id for req in new_batch.reqs]
            else:
                req_ids = []

            dist.broadcast_object_list([len(req_ids)], src=0, group=self.mulitnode_group)
            if len(req_ids) == 0:
                # Nothing proposed: workers saw count 0 and also skip the
                # id broadcast and the vote below.
                new_batch = None
            else:
                dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
                # Master votes "yes" for everything it proposed; MIN keeps a
                # request only if every rank voted 1.
                req_id_select_mark = torch.tensor(
                    [1] * len(req_ids), dtype=torch.int32, device="cpu"
                )
                dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
                # Requests rejected by any rank go back to the FRONT of the
                # waiting queue so they are retried first.
                back_req_list = []
                for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
                    if select == 0:
                        back_req_list.append(new_batch.pop_req(req_id))
                self.req_queue.waiting_req_list = back_req_list + self.req_queue.waiting_req_list
                if new_batch.is_clear():
                    new_batch = None
        else:
            # Worker side: mirror the master's broadcasts.
            req_nums = [None]
            dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
            req_num = req_nums[0]
            if req_num == 0:
                new_batch = None
            else:
                req_ids = [None for _ in range(req_num)]
                dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
                # Vote 1 only for the proposed ids this rank actually holds.
                waiting_ids = {req.request_id for req in self.req_queue.waiting_req_list}
                req_id_select_mark = torch.tensor(
                    [1 if req_id in waiting_ids else 0 for req_id in req_ids],
                    dtype=torch.int32,
                    device="cpu",
                )
                dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
                select_req_ids = [
                    req_id
                    for req_id, select in zip(req_ids, req_id_select_mark.numpy())
                    if select == 1
                ]

                # O(n) selection via id lookup instead of a nested scan over
                # the waiting list (request ids are assumed unique there).
                waiting_by_id = {req.request_id: req for req in self.req_queue.waiting_req_list}
                select_reqs = [
                    waiting_by_id[req_id] for req_id in select_req_ids if req_id in waiting_by_id
                ]
                selected_id_set = set(select_req_ids)
                self.req_queue.waiting_req_list = [
                    req
                    for req in self.req_queue.waiting_req_list
                    if req.request_id not in selected_id_set
                ]
                if select_reqs:
                    new_batch = Batch(-1, reqs=select_reqs, dp_size_in_node=self.dp_size_in_node)
                else:
                    new_batch = None

        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)

        dist.barrier(group=self.mulitnode_group)
    except Exception as e:
        logger.exception(str(e))
        # Bare raise preserves the original traceback without adding a frame.
        raise
    return
429439
430440 async def _recv_new_reqs_and_schedule (self ):
0 commit comments