Skip to content

Commit 746bf84

Browse files
committed
fix
1 parent 37ace6a commit 746bf84

File tree

4 files changed

+92
-11
lines changed

4 files changed

+92
-11
lines changed

lightllm/server/api_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
189189
parser.add_argument(
190190
"--router_max_wait_tokens",
191191
type=int,
192-
default=6,
192+
default=1,
193193
help="schedule new requests after every router_max_wait_tokens decode steps.",
194194
)
195195
parser.add_argument(

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class StartArgs:
4242
log_stats_interval: int = field(default=10)
4343
router_token_ratio: float = field(default=0.0)
4444
router_max_new_token_len: int = field(default=1024)
45-
router_max_wait_tokens: int = field(default=6)
45+
router_max_wait_tokens: int = field(default=1)
4646
dp_prefill_wait_step: int = field(default=0)
4747
disable_aggressive_schedule: bool = field(default=False)
4848
disable_dynamic_prompt_cache: bool = field(default=False)
lightllm/server/router/model_infer/mode_backend/chunked_prefill/control_state.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from enum import Enum
2+
from typing import List
3+
from lightllm.utils.envs_utils import get_env_start_args
4+
from lightllm.server.router.model_infer.infer_batch import InferReq
5+
6+
class ControlState:
    """Decide, once per loop iteration, whether to prefill, decode, or idle.

    Two policies are supported, chosen from the start args at construction:

    * aggressive (default unless ``disable_aggressive_schedule`` is set):
      pending prefill requests always run before decode requests.
    * non-aggressive: up to ``decode_max_step`` decode steps are allowed
      between prefills while decode requests exist; the budget is refilled
      every time a prefill is selected.
    """

    def __init__(self):
        # Aggressive scheduling is active unless the start args disable it.
        aggressive_disabled = get_env_start_args().disable_aggressive_schedule
        self.is_aggressive_schedule = not aggressive_disabled

        # Non-aggressive scheduling parameters: the decode budget is never
        # smaller than one step, and it starts out full.
        wait_tokens = get_env_start_args().router_max_wait_tokens
        self.decode_max_step = max(1, wait_tokens)
        self.left_decode_num = self.decode_max_step

        # Number of select_run_way() calls made so far.
        self.step_count = 0

    def select_run_way(
        self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]
    ) -> "RunWay":
        """Pick how the current step should run.

        Increments ``step_count`` and delegates to the policy selected at
        construction time.

        Returns:
            RunWay: PREFILL, DECODE or PASS.
        """
        self.step_count += 1
        chooser = self._agressive_way if self.is_aggressive_schedule else self._normal_way
        return chooser(prefill_reqs=prefill_reqs, decode_reqs=decode_reqs)

    def _agressive_way(self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]):
        """Aggressive policy: prefill first, then decode, otherwise pass."""
        if prefill_reqs:
            return RunWay.PREFILL
        return RunWay.DECODE if decode_reqs else RunWay.PASS

    def _normal_way(self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]):
        """Non-aggressive policy: spend the decode budget before prefilling."""
        # Decode work exists and budget remains: consume one decode step.
        if decode_reqs and self.left_decode_num > 0:
            self.left_decode_num -= 1
            return RunWay.DECODE

        # Budget exhausted (or nothing to decode): a prefill refills the budget.
        if prefill_reqs:
            self.left_decode_num = self.decode_max_step
            return RunWay.PREFILL

        # No prefill waiting: keep decoding if possible, else idle this step.
        return RunWay.DECODE if decode_reqs else RunWay.PASS

    def try_recover_paused_reqs(self) -> bool:
        """Return True when ``step_count`` is a multiple of 100 (including 0)."""
        remainder = self.step_count % 100
        return remainder == 0
58+
59+
60+
61+
class RunWay(Enum):
    """What the inference loop should do on a given step."""

    PREFILL = 1  # run a prefill pass
    DECODE = 2  # run a decode pass
    PASS = 3  # nothing to run this step

    def is_prefill(self):
        """Return True if this member is PREFILL."""
        return self is RunWay.PREFILL

    def is_decode(self):
        """Return True if this member is DECODE."""
        return self is RunWay.DECODE

    def is_pass(self):
        """Return True if this member is PASS."""
        return self is RunWay.PASS

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717
from lightllm.utils.log_utils import init_logger
1818
from lightllm.utils.dist_utils import get_current_device_id
1919
from lightllm.utils.envs_utils import get_env_start_args
20+
from .control_state import ControlState
2021

2122
logger = init_logger(__name__)
2223

2324

2425
class ChunkedPrefillBackend(ModeBackend):
2526
def __init__(self) -> None:
2627
super().__init__()
28+
29+
# 用于控制每一步是执行prefill 和 decode 还是跳过
30+
self.control_state_machine = ControlState()
2731

2832
# 在 mtp 模式下切换绑定的prefill 和 decode 函数
2933
if get_env_start_args().mtp_mode:
@@ -43,25 +47,29 @@ def infer_loop(self):
4347

4448
self._try_read_new_reqs()
4549

46-
prefill_reqs, decode_reqs = self._get_classed_reqs()
47-
if prefill_reqs:
50+
prefill_reqs, decode_reqs = self._get_classed_reqs(recover_paused=self.control_state_machine.try_recover_paused_reqs())
51+
52+
run_way = self.control_state_machine.select_run_way(prefill_reqs=prefill_reqs,
53+
decode_reqs=decode_reqs)
54+
55+
if run_way.is_prefill():
4856
self.prefill(
4957
event_pack=event_pack,
5058
prefill_reqs=prefill_reqs,
5159
)
5260
continue
53-
54-
if decode_reqs:
61+
elif run_way.is_decode():
5562
self.decode(
5663
event_pack=event_pack,
5764
decode_reqs=decode_reqs,
5865
)
5966
continue
60-
61-
event_pack.notify_post_handle_and_wait_pre_post_handle()
62-
event_pack.notify_forward_and_wait_post_handle()
63-
event_pack.notify_pre_post_handle()
64-
continue
67+
elif run_way.is_pass():
68+
event_pack.notify_post_handle_and_wait_pre_post_handle()
69+
event_pack.notify_forward_and_wait_post_handle()
70+
event_pack.notify_pre_post_handle()
71+
continue
72+
6573
except BaseException as e:
6674
self.logger.exception(str(e))
6775
raise e

0 commit comments

Comments
 (0)