Commit 4a6eca2
add req pause for dp (ModelTC#822)
Co-authored-by: baishihao <baishihao@sensetime.com>
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>

1 parent 84d35d3 · commit 4a6eca2

5 files changed, +53 -36 lines changed

lightllm/server/router/batch.py

Lines changed: 10 additions & 0 deletions

@@ -30,6 +30,16 @@ def get_batch_decode_need_tokens(self):
 
         return new_batch_decode_need_tokens
 
+    def get_req_list_for_dp(self, dp_index: int):
+        if self.dp_size_in_node == 1:
+            return self.reqs
+
+        req_list = []
+        for req in self.reqs:
+            if req.sample_params.suggested_dp_index == dp_index:
+                req_list.append(req)
+        return req_list
+
     def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
         unfinished_req_ids = []
         for req in self.reqs:
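The new Batch.get_req_list_for_dp helper is the basic building block of this change: it partitions a mixed running batch by the DP rank each request was routed to. A minimal runnable sketch of the same logic, using hypothetical stand-ins for Req and its sample_params (names are illustrative only):

    # Hypothetical stand-ins for the real Req / sample_params objects.
    from dataclasses import dataclass

    @dataclass
    class SampleParams:
        suggested_dp_index: int

    @dataclass
    class Req:
        request_id: int
        sample_params: SampleParams

    def get_req_list_for_dp(reqs, dp_size_in_node, dp_index):
        # With a single DP rank there is nothing to partition.
        if dp_size_in_node == 1:
            return reqs
        # Otherwise keep only the requests routed to this DP rank.
        return [r for r in reqs if r.sample_params.suggested_dp_index == dp_index]

    reqs = [Req(i, SampleParams(suggested_dp_index=i % 2)) for i in range(4)]
    assert [r.request_id for r in get_req_list_for_dp(reqs, 2, 0)] == [0, 2]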

lightllm/server/router/manager.py

Lines changed: 25 additions & 28 deletions

@@ -244,17 +244,19 @@ async def loop_for_fwd(
             estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i)
             logger.debug(
                 f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n"
-                f"dp_i {d_i} paused req num: {self.req_queue.get_paused_req_num()} \n"
+                f"dp_i {d_i} paused req num: {self.req_queue.get_paused_req_num(d_i)} \n"
                 f"dp_i {d_i} frozen token num: {frozen_token_num} \n"
                 f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n"
                 f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n"
                 f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token"
             )
+            self.metric_client.gauge_set(
+                "lightllm_batch_pause_size", self.req_queue.get_paused_req_num(d_i)
+            )
         # pd decode mode need to update token_load more frequently
         self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode)
         self.stats_tool.print_stats()
         self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs))
-        self.metric_client.gauge_set("lightllm_batch_pause_size", self.req_queue.get_paused_req_num())
         self.metric_client.gauge_set("lightllm_queue_size", self.req_queue.get_wait_req_num())
         self.metric_client.gauge_set(
             "lightllm_batch_current_max_tokens",

@@ -356,23 +358,22 @@ async def _step(self):
             self.running_batch.merge(new_mini_batch)
             return
 
-        # Normal decode phase: decode directly when possible, otherwise pause
-        # some requests via the pause strategy to free managed tokens.
-        if self._can_decode(self.running_batch):
-            self.stats_tool.count_output_tokens(self.running_batch)
-            await self._decode_batch(self.running_batch)
-            self._filter_runing_batch()
-            self.has_wait_tokens += 1
-            return
-        else:
-            # pause strategy
-            paused_reqs = select_paused_reqs(
-                self.running_batch, self.pause_strategy, self.req_queue, self.max_total_token_num
-            )
-            await self._pause_reqs(paused_reqs)
-            logger.debug(f"paused req num: {self.req_queue.get_paused_req_num()}")
-            self.has_wait_tokens = 0
-            return
+        # Check whether some requests need to be paused before decode.
+        for dp_index in range(self.dp_size_in_node):
+            while not self._can_decode(self.running_batch, dp_index=dp_index):
+                # pause strategy
+                paused_reqs = select_paused_reqs(
+                    self.running_batch, self.pause_strategy, self.req_queue, self.max_total_token_num, dp_index=dp_index
+                )
+                await self._pause_reqs(paused_reqs)
+                logger.debug(f"DP index {dp_index} paused req num: {self.req_queue.get_paused_req_num(dp_index)}")
+                self.has_wait_tokens = 0
+
+        # Decode
+        self.stats_tool.count_output_tokens(self.running_batch)
+        await self._decode_batch(self.running_batch)
+        self._filter_runing_batch()
+        self.has_wait_tokens += 1
         return
 
     async def _prefill_batch(self, batch: Batch):

@@ -416,16 +417,12 @@ def _filter_runing_batch(self):
             self.running_batch = None
             return
 
-    def _can_decode(self, batch: Batch):
-        # In pd-disaggregated mode, only conservative scheduling is used for now, which
-        # guarantees that GPU memory tokens are sufficient when a request enters decode.
-        # In deepseekv2 dp mode, conservative scheduling is also guaranteed to suffice.
-        if self.is_pd_run_mode or self.dp_size_in_node > 1 or self.is_safe_schedule:
+    def _can_decode(self, batch: Batch, dp_index: int):
+        if self.is_pd_run_mode or self.is_safe_schedule:
             return True
-
-        # The check below is only enabled when dp == 1.
-        assert self.dp_size_in_node == 1
-        return batch.get_batch_decode_need_tokens()[0] + self.get_used_tokens(0) <= self.max_total_token_num
+        return (
+            batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num
+        )
 
     def get_used_tokens(self, dp_index):
         if self.args.use_dynamic_prompt_cache:
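The control-flow change in _step is the heart of the patch: instead of an all-or-nothing decode guard, the router now loops over DP ranks and pauses requests rank by rank until _can_decode holds everywhere, and only then decodes. A condensed sketch of that shape, where can_decode, pause_one, and decode are hypothetical stand-ins for the router's real methods:

    def step_decode(running_batch, dp_size_in_node, can_decode, pause_one, decode):
        # Pause requests rank by rank until every DP rank fits its token budget.
        for dp_index in range(dp_size_in_node):
            while not can_decode(running_batch, dp_index):
                pause_one(dp_index)  # select_paused_reqs(...) + _pause_reqs(...) in the real code
        # Only once every rank can decode does the whole batch take a decode step.
        decode(running_batch)

Note that _can_decode now takes a dp_index and checks batch.get_batch_decode_need_tokens()[dp_index] + get_used_tokens(dp_index) against max_total_token_num, replacing the old shortcut that returned True unconditionally whenever dp_size_in_node > 1.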

lightllm/server/router/pause_strategy.py

Lines changed: 9 additions & 5 deletions

@@ -3,6 +3,7 @@
 from typing import List, Tuple
 from .batch import Batch, Req
 from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.req_queue.dp_base_queue import DpQueue
 
 
 class Strategy:

@@ -14,13 +15,16 @@ class Fcfs(Strategy):
     def __init__(self) -> None:
         super().__init__()
 
-    def ordering_reqs(self, batch: Batch):
-        reqs = [req for req in batch.reqs]
-        return sorted(reqs, key=lambda req: req.request_id, reverse=True)
+    def ordering_reqs(self, reqs: List):
+        return reqs[::-1]
 
 
-def select_paused_reqs(batch: Batch, strategy: Strategy, req_queue: BaseQueue, max_total_token_num):
-    reqs: List[Req] = strategy.ordering_reqs(batch)
+def select_paused_reqs(
+    batch: Batch, strategy: Strategy, req_queue: BaseQueue, max_total_token_num: int, dp_index: int
+) -> List[Req]:
+    if isinstance(req_queue, DpQueue):
+        req_queue = req_queue.get_dp_queue(dp_index)
+    reqs: List[Req] = strategy.ordering_reqs(batch.get_req_list_for_dp(dp_index))
 
     if len(reqs) == 0:
         return []
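Fcfs.ordering_reqs drops the sort in favor of a plain reversal. The two are equivalent only under the assumption that the request list is already held in ascending request_id (arrival) order, which appending preserves; a small sketch of the equivalence:

    reqs = [101, 102, 103]  # request_ids, oldest first (arrival order)

    old_order = sorted(reqs, reverse=True)  # previous Fcfs.ordering_reqs
    new_order = reqs[::-1]                  # new implementation
    assert old_order == new_order == [103, 102, 101]  # newest gets paused first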

lightllm/server/router/req_queue/base_queue.py

Lines changed: 2 additions & 1 deletion

@@ -38,7 +38,8 @@ def extend(self, req_group: List[Req]):
         self.waiting_req_list.extend(req_group)
         return
 
-    def get_paused_req_num(self):
+    def get_paused_req_num(self, fake_dp_index: int = 0):
+        assert fake_dp_index == 0
         return len(self.pause_req_dict)
 
     def get_wait_req_num(self):
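The defaulted fake_dp_index keeps BaseQueue call-compatible with DpQueue, so the manager can ask any queue for its paused count by rank without an isinstance check; the assert catches a non-DP queue being asked about a rank other than 0. An illustrative sketch with made-up queue classes:

    # Made-up stand-ins for BaseQueue and DpQueue, showing the shared call shape.
    class OneRankQueue:
        def __init__(self, paused):
            self._paused = paused
        def get_paused_req_num(self, fake_dp_index: int = 0):
            assert fake_dp_index == 0  # a non-DP queue has a single logical rank
            return self._paused

    class TwoRankQueue:
        def __init__(self, per_rank):
            self._per_rank = per_rank
        def get_paused_req_num(self, dp_index: int = 0):
            return self._per_rank[dp_index]

    for q, d_i in [(OneRankQueue(3), 0), (TwoRankQueue([1, 2]), 1)]:
        print(q.get_paused_req_num(d_i))  # prints 3, then 2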

lightllm/server/router/req_queue/dp_base_queue.py

Lines changed: 7 additions & 2 deletions

@@ -19,10 +19,15 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
         self.inner_queues: List[BaseQueue] = [
             base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node)
         ]
+
         return
 
-    def get_paused_req_num(self):
-        return sum(queue.get_paused_req_num() for queue in self.inner_queues)
+    def get_dp_queue(self, dp_index: int):
+        assert dp_index < self.dp_size_in_node, "dp index out of range"
+        return self.inner_queues[dp_index]
+
+    def get_paused_req_num(self, dp_index: int = 0):
+        return self.inner_queues[dp_index].get_paused_req_num()
 
     def get_wait_req_num(self):
         return sum(queue.get_wait_req_num() for queue in self.inner_queues)
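Note the semantic shift in DpQueue.get_paused_req_num: the old version summed paused requests across all inner queues, while the new one reports a single rank, matching the per-rank logging and gauge in manager.py. A toy illustration of the difference (counts are made up):

    inner_paused = [2, 0, 5]  # paused request count per DP rank (made-up numbers)

    old_total = sum(inner_paused)   # previous get_paused_req_num(): 7
    new_per_rank = inner_paused[1]  # new get_paused_req_num(dp_index=1): 0
    assert (old_total, new_per_rank) == (7, 0)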
