Skip to content

Commit 31a3ed8

Browse files
committed
fix multinode tp error
1 parent c8bd663 commit 31a3ed8

File tree

1 file changed

+65
-55
lines changed

1 file changed

+65
-55
lines changed

lightllm/server/router/manager.py

Lines changed: 65 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -368,63 +368,73 @@ def _generate_new_batch(self):
368368
return
369369

370370
def _multinode_tp_generate_new_batch(self):
    """Schedule a new batch consistently across all nodes in a multi-node TP group.

    The master node (rank 0 of the group) runs the scheduler and broadcasts the
    chosen request ids; every node then votes (all-reduce MIN over 0/1 marks) on
    which of those requests it actually holds in its waiting queue, so the final
    batch contains only requests present on every node. The agreed batch is merged
    into ``self.schedule_new_batch``.

    Side effects: mutates ``self.req_queue.waiting_req_list`` and
    ``self.schedule_new_batch``; performs collective ops on
    ``self.mulitnode_group`` (NOTE(review): attribute name has a pre-existing
    'mulitnode' typo defined elsewhere in this class — kept for consistency).

    Raises:
        Exception: any failure is logged via ``logger.exception`` and re-raised.
    """
    try:
        dist.barrier(group=self.mulitnode_group)

        # Scheduling must account for both the currently running batch and
        # requests already scheduled but not yet submitted for inference.
        if self.is_multinode_tp_master:
            new_batch = self.req_queue.generate_new_batch(
                Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
            )
            req_ids = [req.request_id for req in new_batch.reqs] if new_batch is not None else []
            # Broadcast the count first so followers know how many ids to expect;
            # when it is 0, no follow-up broadcast happens (both sides must agree).
            dist.broadcast_object_list([len(req_ids)], src=0, group=self.mulitnode_group)
            if len(req_ids) == 0:
                new_batch = None
            else:
                dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
                # Master marks every id as present; the MIN all-reduce drops any
                # id some node does not hold.
                req_id_select_mark = torch.ones(len(req_ids), dtype=torch.int32, device="cpu")
                dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
                # Requests rejected by the vote go back to the front of the queue.
                back_req_list = [
                    new_batch.pop_req(req_id)
                    for req_id, select in zip(req_ids, req_id_select_mark.numpy())
                    if select == 0
                ]
                self.req_queue.waiting_req_list = back_req_list + self.req_queue.waiting_req_list
                if new_batch.is_clear():
                    new_batch = None
        else:
            req_nums = [None]
            dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
            req_num = req_nums[0]
            if req_num == 0:
                new_batch = None
            else:
                req_ids = [None for _ in range(req_num)]
                dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
                all_req_id_set = {req.request_id for req in self.req_queue.waiting_req_list}
                # Vote 1 only for ids this node actually holds.
                req_id_select_mark = torch.tensor(
                    [1 if req_id in all_req_id_set else 0 for req_id in req_ids],
                    dtype=torch.int32,
                    device="cpu",
                )
                dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
                select_req_ids = [
                    req_id
                    for req_id, select in zip(req_ids, req_id_select_mark.numpy())
                    if select == 1
                ]

                # Single-pass id -> req lookup (first match wins) instead of the
                # previous O(n*m) nested scan; iteration over select_req_ids keeps
                # the master's broadcast ordering for the batch.
                id_to_req = {}
                for req in self.req_queue.waiting_req_list:
                    id_to_req.setdefault(req.request_id, req)
                select_reqs = [id_to_req[req_id] for req_id in select_req_ids if req_id in id_to_req]

                # Remove exactly the selected objects (identity-based) in one pass,
                # avoiding repeated list.remove scans.
                picked = set(map(id, select_reqs))
                self.req_queue.waiting_req_list = [
                    req for req in self.req_queue.waiting_req_list if id(req) not in picked
                ]
                if select_reqs:
                    new_batch = Batch(-1, reqs=select_reqs, dp_size_in_node=self.dp_size_in_node)
                else:
                    new_batch = None

        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)

        dist.barrier(group=self.mulitnode_group)
    except Exception as e:
        logger.exception(str(e))
        # Bare raise preserves the original traceback without adding a re-raise frame.
        raise
    return
429439

430440
async def _recv_new_reqs_and_schedule(self):

0 commit comments

Comments
 (0)