Commit b172eaf

fix
1 parent 4c111a3 commit b172eaf

7 files changed, +78 −27 lines changed

lightllm/server/router/batch.py

Lines changed: 3 additions & 3 deletions
@@ -54,10 +54,10 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
         return
 
-    def pop_req(self, req_id):
+    def pop_req(self, req_id) -> Req:
         self.reqs = [req for req in self.reqs if req.request_id != req_id]
-        self.id_to_reqs.pop(req_id)
-        return
+        req = self.id_to_reqs.pop(req_id)
+        return req
 
     def is_clear(self):
         return len(self.reqs) == 0
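
For reference, pop_req now hands the removed Req back to the caller instead of returning None, which is what lets the scheduler below requeue requests that other nodes rejected. A standalone sketch of the pattern, using simplified stand-ins for the real Batch and Req classes:

class Req:
    def __init__(self, request_id):
        self.request_id = request_id

class Batch:
    def __init__(self, reqs):
        self.reqs = list(reqs)
        self.id_to_reqs = {r.request_id: r for r in self.reqs}

    def pop_req(self, req_id) -> "Req":
        self.reqs = [r for r in self.reqs if r.request_id != req_id]
        # Returning the popped Req lets the caller put it back on a waiting list.
        return self.id_to_reqs.pop(req_id)

batch = Batch([Req(1), Req(2)])
waiting_req_list = []
rejected = batch.pop_req(2)            # e.g. a req some node could not schedule
waiting_req_list.insert(0, rejected)   # push it back to the front of the queue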

lightllm/server/router/manager.py

Lines changed: 68 additions & 12 deletions
@@ -45,6 +45,8 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
         # Compatibility for multi-node pure-TP mode, where 1 // 2 == 0 must be handled
         self.dp_size_in_node = max(1, args.dp // self.nnodes)
         self.is_multinode_tp = args.nnodes > 1 and args.dp == 1
+        self.is_multinode_tp_master = self.is_multinode_tp and args.node_rank == 0
+        self.is_multinode_tp_slave = self.is_multinode_tp and args.node_rank != 0
         self.is_multinode_and_multidp = args.nnodes > 1 and args.dp > 1
         # Whether conservative scheduling is used; it never pauses reqs, but can hurt throughput in some scenarios
         self.is_safe_schedule = args.router_token_ratio == 0.0
@@ -359,21 +361,73 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
         return
 
     def _generate_new_batch(self):
-        limit_router_queue_length = None
-        if self.is_multinode_tp:
-            # use all_reduce to take the minimum across nodes
-            limit_router_queue_length = len(self.req_queue.waiting_req_list)
-            limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
-            dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-            limit_router_queue_length = limit_router_queue_length_tensor.item()
-
         # Scheduling must consider the currently running batch plus requests already scheduled but not yet running.
         new_batch = self.req_queue.generate_new_batch(
-            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch), limit_router_queue_length
+            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
         )
         self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
         return
 
+    def _multinode_tp_generate_new_batch(self):
+        dist.barrier(group=self.mulitnode_group)
+
+        # Scheduling must consider the currently running batch plus requests already scheduled but not yet running.
+        if self.is_multinode_tp_master:
+            new_batch = self.req_queue.generate_new_batch(
+                Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
+            )
+            if new_batch is not None:
+                req_ids = [req.request_id for req in new_batch.reqs]
+            else:
+                req_ids = []
+            dist.broadcast_object_list([len(req_ids)], src=0, group=self.mulitnode_group)
+            dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
+            req_id_select_mark = [1 for _ in range(len(req_ids))]
+            req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            back_req_list = []
+            for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+                if select == 0:
+                    req = new_batch.pop_req(req_id)
+                    back_req_list.append(req)
+            self.req_queue.waiting_req_list = back_req_list + self.req_queue.waiting_req_list
+            if new_batch.is_clear():
+                new_batch = None
+        else:
+            req_nums = [None]
+            dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+            req_num = req_nums[0]
+            req_ids = [None for _ in range(req_num)]
+            dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
+            all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+            req_id_select_mark = []
+            for req_id in req_ids:
+                req_id_select_mark.append(1 if req_id in all_req_id_set else 0)
+            req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            select_req_ids = []
+            for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+                if select == 1:
+                    select_req_ids.append(req_id)
+
+            select_reqs = []
+            for req_id in select_req_ids:
+                for req in self.req_queue.waiting_req_list:
+                    if req.request_id == req_id:
+                        select_reqs.append(req)
+
+            for req in select_reqs:
+                self.req_queue.waiting_req_list.remove(req)
+            if select_reqs:
+                new_batch = Batch(-1, reqs=select_reqs, dp_size_in_node=self.dp_size_in_node)
+            else:
+                new_batch = None
+
+        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
+
+        dist.barrier(group=self.mulitnode_group)
+        return
+
     async def _recv_new_reqs_and_schedule(self):
         if not hasattr(self, "recv_max_count"):
             self.recv_max_count = 64
@@ -394,9 +448,11 @@ async def _recv_new_reqs_and_schedule(self):
             # When the queue has started to drain, lower the per-round receive count
             self.recv_max_count = 64
 
-        # Only run new scheduling when nothing is paused on the inference side
-        if self._get_paused_req_num() == 0:
-            self._generate_new_batch()
+        if self.is_multinode_tp:
+            self._multinode_tp_generate_new_batch()
+        else:
+            if self._get_paused_req_num() == 0:
+                self._generate_new_batch()
         return
 
     def clean_up(self):
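
The new method is, at its core, a small consensus handshake: the master proposes a batch of request ids, each node votes on which of those ids it actually holds in its own waiting list, and a MIN all-reduce keeps only the intersection, so every node commits to an identical batch. A minimal standalone sketch of that handshake (assuming torch.distributed is already initialized on every node with a CPU-capable backend such as gloo; agree_on_req_ids is an illustrative helper, not part of lightllm):

import torch
import torch.distributed as dist

def agree_on_req_ids(proposed_ids, local_id_set, group, is_master):
    """Return the subset of the master's proposed ids that every rank holds."""
    if is_master:
        # Rank 0 announces how many ids follow, then the ids themselves.
        dist.broadcast_object_list([len(proposed_ids)], src=0, group=group)
        dist.broadcast_object_list(proposed_ids, src=0, group=group)
        ids = proposed_ids
        mark = [1] * len(ids)  # the master holds everything it proposed
    else:
        num = [None]
        dist.broadcast_object_list(num, src=0, group=group)
        ids = [None] * num[0]
        dist.broadcast_object_list(ids, src=0, group=group)
        mark = [1 if i in local_id_set else 0 for i in ids]

    # MIN-reduce the per-rank ownership marks: an id survives only if every
    # rank marked it 1, i.e. the reduced mask is the intersection across ranks.
    mark_t = torch.tensor(mark, dtype=torch.int32, device="cpu")
    dist.all_reduce(mark_t, op=dist.ReduceOp.MIN, group=group)
    return [i for i, m in zip(ids, mark_t.tolist()) if m == 1]

Because every rank applies the same reduced mask, the nodes agree on the batch while exchanging only request ids, never the Req objects themselves; ids the master proposed but some node lacked are pushed back to the front of the waiting list, as the master branch above does with back_req_list.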

lightllm/server/router/req_queue/base_queue.py

Lines changed: 1 addition & 3 deletions
@@ -58,12 +58,10 @@ def get_batch_dp_req_size(self, current_batch: Batch):
 
         return len([req for req in current_batch.reqs if req.sample_params.suggested_dp_index == self.dp_index])
 
-    def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+    def generate_new_batch(self, current_batch: Batch):
         """
         args:
         current_batch: current batch
-        limit_router_queue_length: the least length of waiting list across all nodes.
-        It only works when nnodes > 1 and dp_size == 1.
         return:
         new batch
         """

lightllm/server/router/req_queue/chunked_prefill/beam_impl.py

Lines changed: 2 additions & 2 deletions
@@ -69,12 +69,12 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new
             return False, new_batch_first_router_need_tokens
 
     # @calculate_time(show=True, min_cost_ms=10)
-    def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+    def generate_new_batch(self, current_batch: Batch):
         if len(self.waiting_req_list) == 0:
             return None
 
         # If the number of already scheduled requests has reached the cap, schedule nothing new.
-        exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict)
+        exist_req_num = self.get_batch_dp_req_size(current_batch)
         req_is_full = exist_req_num >= self.running_max_req_size
         if req_is_full:
             return None

lightllm/server/router/req_queue/chunked_prefill/impl.py

Lines changed: 2 additions & 5 deletions
@@ -54,7 +54,7 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens
             return False, new_batch_first_router_need_tokens
 
     # @calculate_time(show=True, min_cost_ms=10)
-    def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+    def generate_new_batch(self, current_batch: Batch):
         if len(self.waiting_req_list) == 0:
             return None
 
@@ -75,10 +75,7 @@ def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: in
         abort_req_list = []
         aborted_count = 0
 
-        if limit_router_queue_length is None:
-            waiting_queue = self.waiting_req_list
-        else:
-            waiting_queue = self.waiting_req_list[:limit_router_queue_length]
+        waiting_queue = self.waiting_req_list
 
         for req in waiting_queue:
             if req.is_aborted:

lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def _init_cache_list(self, current_batch: Batch, is_busy):
         return
 
     # @calculate_time(show=True, min_cost_ms=10)
-    def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+    def generate_new_batch(self, current_batch: Batch):
         if len(self.waiting_req_list) == 0:
             return None
 

lightllm/server/router/req_queue/dp_base_queue.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def get_wait_req_num(self):
         return sum(queue.get_wait_req_num() for queue in self.inner_queues)
 
     # @calculate_time(show=True, min_cost_ms=10)
-    def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+    def generate_new_batch(self, current_batch: Batch):
         batches = [
             self.inner_queues[dp_index].generate_new_batch(current_batch) for dp_index in range(self.dp_size_in_node)
         ]
