
Commit 242590f

token_healing mode and out constraint mode use chuncked prefill. (#846)
1 parent 59bfff7 commit 242590f

17 files changed (+243 −228 lines)

docs/CN/source/getting_started/quickstart.rst

Lines changed: 1 addition & 5 deletions
@@ -113,7 +113,6 @@
 $ --tokenizer_mode fast \
 $ --pd_master_ip /your/host/ip \
 $ --pd_master_port 60011 \
-$ --use_dynamic_prompt_cache \
 $ --max_req_total_len 16000 \
 $ --running_max_req_size 128 \
 $ --disable_cudagraph
@@ -133,8 +132,7 @@
 $ --graph_max_batch_size 16 \
 $ --tokenizer_mode fast \
 $ --pd_master_ip /your/host/ip \
-$ --pd_master_port 60011 \
-$ --use_dynamic_prompt_cache
+$ --pd_master_port 60011

 .. note::
     The tp sizes of the prefill and decoding stages must stay consistent. The number of prefill and decode nodes can vary, and prefill and decode can be deployed across machines.
@@ -215,7 +213,6 @@ $ --config_server_port 60088 \
 $ --nccl_port 2732 \
 $ --max_total_token_num 400000 \
 $ --tokenizer_mode fast \
-$ --use_dynamic_prompt_cache \
 $ --max_req_total_len 16000 \
 $ --running_max_req_size 128 \
 $ --disable_cudagraph \
@@ -236,7 +233,6 @@ $ --config_server_port 60088 \
 $ --graph_max_len_in_batch 2048 \
 $ --graph_max_batch_size 16 \
 $ --tokenizer_mode fast \
-$ --use_dynamic_prompt_cache \
 $ --config_server_host <config_server_host> \
 $ --config_server_port <config_server_port>

docs/EN/source/getting_started/quickstart.rst

Lines changed: 1 addition & 3 deletions
@@ -110,7 +110,6 @@ Open a new terminal and run the prefill service
 $ --tokenizer_mode fast \
 $ --pd_master_ip /your/host/ip \
 $ --pd_master_port 60011 \
-$ --use_dynamic_prompt_cache \
 $ --max_req_total_len 16000 \
 $ --running_max_req_size 128 \
 $ --disable_cudagraph
@@ -130,8 +129,7 @@ Open a new terminal and run the decoding service
 $ --graph_max_batch_size 16 \
 $ --tokenizer_mode fast \
 $ --pd_master_ip /your/host/ip \
-$ --pd_master_port 60011 \
-$ --use_dynamic_prompt_cache
+$ --pd_master_port 60011

 .. note::
     The tp size for the prefill and decoding stages should remain consistent.

lightllm/server/api_cli.py

Lines changed: 4 additions & 1 deletion
@@ -184,7 +184,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
         disabling it allows the router_max_wait_tokens parameter to work more effectively.""",
     )

-    parser.add_argument("--use_dynamic_prompt_cache", action="store_true", help="use_dynamic_prompt_cache test")
+    parser.add_argument(
+        "--use_dynamic_prompt_cache", action="store_true", help="This argument is deprecated and no longer in use."
+    )
+    parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache")

     parser.add_argument("--chunked_prefill_size", type=int, default=8192, help="chunked prefill size")
     parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")

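The practical effect of this change is that the dynamic prompt cache (radix cache) is now on by default and is turned off with the new flag, while the old flag is kept only so existing launch scripts do not break. Below is a minimal sketch of the flag semantics, assuming a plain argparse parser; only the two flag names and their help strings come from the diff, the rest is illustrative.

import argparse

parser = argparse.ArgumentParser()
# Deprecated no-op flag, kept for backward compatibility with old launch scripts.
parser.add_argument(
    "--use_dynamic_prompt_cache", action="store_true",
    help="This argument is deprecated and no longer in use.",
)
# New switch: the cache is enabled unless this is passed.
parser.add_argument(
    "--disable_dynamic_prompt_cache", action="store_true",
    help="disable dynamic prompt cache",
)

args = parser.parse_args([])
print(not args.disable_dynamic_prompt_cache)  # True: cache enabled by default

args = parser.parse_args(["--disable_dynamic_prompt_cache"])
print(not args.disable_dynamic_prompt_cache)  # False: cache explicitly disabled
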
lightllm/server/api_start.py

Lines changed: 5 additions & 6 deletions
@@ -95,18 +95,17 @@ def normal_or_p_d_start(args):
     assert [
         args.disable_chunked_prefill,
         args.diverse_mode,
-        args.token_healing_mode,
         args.use_reward_model,
         args.return_all_prompt_logprobs,
-        args.output_constraint_mode != "none",
     ].count(True) <= 1
-    # Some modes cannot run together with dynamic_prompt_cache yet, TODO.
-    if args.use_dynamic_prompt_cache:
-        assert args.token_healing_mode is False

     # chunked prefill must be enabled together with dynamic_prompt_cache
     if not args.disable_chunked_prefill:
-        assert args.use_dynamic_prompt_cache is True
+        assert args.disable_dynamic_prompt_cache is False
+    if args.output_constraint_mode != "none":
+        assert args.disable_dynamic_prompt_cache is False
+    if args.token_healing_mode:
+        assert args.disable_dynamic_prompt_cache is False

     # Some modes cannot yet work with the advanced dynamic scheduling algorithm, TODO.
     if args.diverse_mode:

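Two validation patterns are at work here: the `[...].count(True) <= 1` idiom enforces that at most one mutually exclusive mode is enabled, and the new `assert args.disable_dynamic_prompt_cache is False` checks express that chunked prefill, output-constraint mode, and token-healing mode all need the dynamic prompt cache. A standalone sketch follows, assuming a SimpleNamespace stands in for the parsed arguments; the field names mirror the diff, but the helper function name is made up.

from types import SimpleNamespace

def check_start_args(args):
    # At most one mutually exclusive mode may be enabled; token_healing_mode and
    # output_constraint_mode were dropped from this list by the commit because
    # they now run on the chunked-prefill path.
    assert [
        args.disable_chunked_prefill,
        args.diverse_mode,
        args.use_reward_model,
        args.return_all_prompt_logprobs,
    ].count(True) <= 1

    # These modes all rely on the dynamic prompt cache, so it must not be disabled.
    if not args.disable_chunked_prefill:
        assert args.disable_dynamic_prompt_cache is False
    if args.output_constraint_mode != "none":
        assert args.disable_dynamic_prompt_cache is False
    if args.token_healing_mode:
        assert args.disable_dynamic_prompt_cache is False

# Illustrative args object (not the real parsed namespace).
args = SimpleNamespace(
    disable_chunked_prefill=False,
    diverse_mode=False,
    use_reward_model=False,
    return_all_prompt_logprobs=False,
    disable_dynamic_prompt_cache=False,
    output_constraint_mode="none",
    token_healing_mode=True,
)
check_start_args(args)  # passes: token healing with chunked prefill and the cache enabled
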
lightllm/server/core/objs/req.py

Lines changed: 23 additions & 23 deletions
@@ -272,29 +272,6 @@ def get_first_router_need_tokens(self):
         return self.input_len + self.shm_cur_output_len


-class TokenHealingReq(NormalReq):
-    _pack_ = 4
-
-    def post_init(
-        self,
-    ):
-        for prefix_token_num in range(2, -1, -1):
-            if self.input_len > prefix_token_num:
-                self.input_len -= prefix_token_num
-                self.prefix_token_ids.set_token_ids(
-                    self.shm_prompt_ids.arr[self.input_len : (self.input_len + prefix_token_num)]
-                )
-                break
-
-        # Because the original output-token budget is partly consumed by the decode steps
-        # spent on the in-between prefix completion, a few extra decode steps are added by
-        # default. In token healing mode the estimated lifetime of the generated tokens can
-        # be inaccurate, so to ease the resulting memory-estimation problems for the
-        # scheduler, the generated-token length is extended by the prefix size + 6.
-        self.sample_params.max_new_tokens = self.sample_params.max_new_tokens + self.prefix_token_ids.size + 6
-        return
-
-
 class ChunkedPrefillReq(Req):
     _pack_ = 4

@@ -333,3 +310,26 @@ def get_decode_need_tokens(self):
     def get_first_router_need_tokens(self):

         return min(self.input_len + self.shm_cur_output_len, self.chunked_prefill_size)
+
+
+class TokenHealingReq(ChunkedPrefillReq):
+    _pack_ = 4
+
+    def post_init(
+        self,
+    ):
+        for prefix_token_num in range(2, -1, -1):
+            if self.input_len > prefix_token_num:
+                self.input_len -= prefix_token_num
+                self.prefix_token_ids.set_token_ids(
+                    self.shm_prompt_ids.arr[self.input_len : (self.input_len + prefix_token_num)]
+                )
+                break
+
+        # Because the original output-token budget is partly consumed by the decode steps
+        # spent on the in-between prefix completion, a few extra decode steps are added by
+        # default. In token healing mode the estimated lifetime of the generated tokens can
+        # be inaccurate, so to ease the resulting memory-estimation problems for the
+        # scheduler, the generated-token length is extended by the prefix size + 6.
+        self.sample_params.max_new_tokens = self.sample_params.max_new_tokens + self.prefix_token_ids.size + 6
+        return

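TokenHealingReq now inherits from ChunkedPrefillReq, so token-healing requests go through the chunked-prefill scheduling path; its post_init logic is unchanged: it moves up to the last two prompt tokens into prefix_token_ids so decoding can regenerate ("heal") the boundary, and pads max_new_tokens by the prefix size plus 6. A small sketch of just the trimming rule on a plain Python list follows; the function name and list-based types are illustrative stand-ins, not the shared-memory request object.

def split_healing_prefix(prompt_ids, max_prefix=2):
    # Pick the largest prefix_token_num in {2, 1, 0} that still leaves a
    # non-empty prompt, shorten the prompt by that many tokens, and keep the
    # removed tail as the healing prefix.
    for prefix_token_num in range(max_prefix, -1, -1):
        if len(prompt_ids) > prefix_token_num:
            input_len = len(prompt_ids) - prefix_token_num
            return prompt_ids[:input_len], prompt_ids[input_len:input_len + prefix_token_num]
    return prompt_ids, []

print(split_healing_prefix([11, 12, 13, 14, 15]))  # ([11, 12, 13], [14, 15])
print(split_healing_prefix([11, 12]))              # ([11], [12])
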
lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ class StartArgs:
     router_max_new_token_len: int = field(default=1024)
     router_max_wait_tokens: int = field(default=6)
     disable_aggressive_schedule: bool = field(default=False)
-    use_dynamic_prompt_cache: bool = field(default=False)
+    disable_dynamic_prompt_cache: bool = field(default=False)
     chunked_prefill_size: int = field(default=8192)
     disable_chunked_prefill: bool = field(default=False)
     diverse_mode: bool = field(default=False)

lightllm/server/router/manager.py

Lines changed: 3 additions & 3 deletions
@@ -163,7 +163,7 @@ async def wait_to_model_ready(self):
                 "is_token_healing": self.args.token_healing_mode,
                 "return_all_prompt_logprobs": self.args.return_all_prompt_logprobs,
                 "use_reward_model": self.args.use_reward_model,
-                "use_dynamic_prompt_cache": self.args.use_dynamic_prompt_cache,
+                "disable_dynamic_prompt_cache": self.args.disable_dynamic_prompt_cache,
                 "data_type": self.args.data_type,
                 "eos_id": self.eos_id,
                 "diverse_mode": self.args.diverse_mode,
@@ -182,7 +182,7 @@ async def wait_to_model_ready(self):
         if self.max_total_token_num is None:
             self.max_total_token_num = await self.model_rpc_client.get_max_total_token_num()
         self.args.max_total_token_num = self.max_total_token_num
-        if self.args.use_dynamic_prompt_cache:
+        if not self.args.disable_dynamic_prompt_cache:
             self.radix_cache_client = RadixCacheReadOnlyClient(
                 get_unique_server_name(),
                 self.max_total_token_num,
@@ -425,7 +425,7 @@ def _can_decode(self, batch: Batch, dp_index: int):
         )

     def get_used_tokens(self, dp_index):
-        if self.args.use_dynamic_prompt_cache:
+        if not self.args.disable_dynamic_prompt_cache:
             return (
                 self.max_total_token_num
                 - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)

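The router-side changes are a mechanical inversion of the old flag: every `if self.args.use_dynamic_prompt_cache` becomes `if not self.args.disable_dynamic_prompt_cache`, so the radix-cache client and the cache-aware token accounting are used unless the cache was explicitly disabled. A rough sketch of that branching follows; the class and the fallback accounting are illustrative, since the diff only shows the beginning of the cache-enabled expression.

class UsedTokenCounter:
    def __init__(self, max_total_token_num, disable_dynamic_prompt_cache):
        self.max_total_token_num = max_total_token_num
        self.disable_dynamic_prompt_cache = disable_dynamic_prompt_cache

    def get_used_tokens(self, unrefed_token_num, batch_used_tokens):
        if not self.disable_dynamic_prompt_cache:
            # Cache enabled: everything that is not currently un-referenced in
            # the shared memory manager counts as used.
            return self.max_total_token_num - unrefed_token_num
        # Cache disabled: fall back to the batch's own bookkeeping (illustrative).
        return batch_used_tokens

counter = UsedTokenCounter(max_total_token_num=400000, disable_dynamic_prompt_cache=False)
print(counter.get_used_tokens(unrefed_token_num=350000, batch_used_tokens=42000))  # 50000
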
lightllm/server/router/model_infer/mode_backend/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -3,12 +3,12 @@
 from .continues_batch.impl_for_reward_model import RewardModelBackend
 from .chunked_prefill.impl import ChunkedPrefillBackend
 from .diverse_backend.impl import DiversehBackend
-from .continues_batch.impl_for_token_healing import TokenHealingBackend
-from .continues_batch.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
+from .chunked_prefill.impl_for_token_healing import TokenHealingBackend
+from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
 from .chunked_prefill.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
 from .dp_backend.impl import DPChunkedPrefillBackend
 from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ChunckedPrefillForPrefillNode
 from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode
-from .continues_batch.impl_for_xgrammar_mode import XgrammarBackend
+from .chunked_prefill.impl_for_xgrammar_mode import XgrammarBackend
 from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode
 from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def init_model(self, kvargs):
         self.disable_chunked_prefill = kvargs.get("disable_chunked_prefill", False)
         self.chunked_prefill_size = kvargs.get("chunked_prefill_size", None)
         self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
-        self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
+        self.use_dynamic_prompt_cache = not kvargs.get("disable_dynamic_prompt_cache", False)
         self.eos_id: List[int] = kvargs.get("eos_id", [2])
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)

lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_outlines_constraint_mode.py renamed to lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py

Lines changed: 65 additions & 58 deletions
@@ -1,10 +1,9 @@
 import os
 import shutil
 import torch
-from .impl import ContinuesBatchBackend
-from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
+from .impl import ChunkedPrefillBackend
 from lightllm.server.core.objs import FinishStatus
-from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams
+from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
 from lightllm.server.router.model_infer.mode_backend.generic_pre_process import (
     prepare_prefill_inputs,
     prepare_decode_inputs,
@@ -17,7 +16,7 @@
 logger = init_logger(__name__)


-class OutlinesConstraintBackend(ContinuesBatchBackend):
+class OutlinesConstraintBackend(ChunkedPrefillBackend):
     def __init__(self) -> None:
         super().__init__()

@@ -45,63 +44,23 @@ def init_custom(self):
         logger.info(f"eos_ids {self.tokenizer.eos_token_ids}")
         return

-    def prefill(self, reqs: List[Tuple]):
-
-        req_ids = self._init_reqs(reqs)
-
-        # import here; when this mode is not used, the code still runs without these dependencies
-        from outlines.fsm.guide import RegexGuide
-
-        req_objs = self._trans_req_ids_to_req_objs(req_ids)
-        kwargs, run_reqs = prepare_prefill_inputs(req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal)
-        run_reqs: List[InferReq] = run_reqs
-
-        logics = self.model.forward(**kwargs)
-
-        # For logit positions that cannot satisfy the prefix match, set the logits to a large negative value so their probability is masked to 0
-        mask = torch.ones_like(logics, dtype=torch.bool)
-        for i, run_obj in enumerate(run_reqs):
-            run_obj: InferReq = run_obj
-            sample_params = run_obj.sampling_param
-            if sample_params.regular_constraint is not None:
-                sample_params.regex_guide = RegexGuide.from_regex(sample_params.regular_constraint, self.tokenizer)
-            self._mask_req_out_token(i, run_obj, mask)
-
-        logics[mask] = -1000000.0
-
-        next_token_ids, next_token_probs = sample(logics, run_reqs, self.eos_id)
-        next_token_ids = next_token_ids.detach().cpu().numpy()
-        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
-        self._post_handle(
-            run_reqs,
-            next_token_ids,
-            next_token_logprobs,
-            is_chuncked_mode=False,
-            do_filter_finished_reqs=False,
-            extra_post_req_handle_func=self._update_state_fsm,
-        )
-
-        return
-
     def decode(self):
         uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
             g_infer_context.infer_req_ids
         )
-        assert len(uninit_reqs) == 0
-        assert len(prefill_reqs) == 0

         if aborted_reqs:
             g_infer_context.filter_reqs(aborted_reqs)

+        # decode first
         if decode_reqs:
             kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
-            run_reqs: List[InferReq] = run_reqs
-
             logits = self.model.forward(**kwargs)
+            self._overlap_req_init_and_filter(
+                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
+            )

-            self._overlap_req_init_and_filter(uninit_reqs=[], ok_finished_reqs=ok_finished_reqs, clear_list=True)
-
+            self._init_guide_infos(run_reqs)
             all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs])
             if not all_has_no_constraint:
                 mask = torch.ones_like(logits, dtype=torch.bool)
@@ -112,7 +71,6 @@ def decode(self):
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
             self._post_handle(
                 run_reqs,
                 next_token_ids,
@@ -121,8 +79,42 @@ def decode(self):
                 do_filter_finished_reqs=False,
                 extra_post_req_handle_func=self._update_state_fsm,
             )
+            logits = None
+
+        # then prefill
+        if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0):
+            if prefill_reqs:
+                self.need_prefill_count -= 1
+                kwargs, run_reqs = prepare_prefill_inputs(
+                    prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
+                )
+                logits = self.model.forward(**kwargs)
+                self._overlap_req_init_and_filter(
+                    uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
+                )
+                # For logit positions that cannot satisfy the prefix match, set the logits to a large negative value so their probability is masked to 0
+                self._init_guide_infos(run_reqs)
+                mask = torch.ones_like(logits, dtype=torch.bool)
+                for i, run_obj in enumerate(run_reqs):
+                    self._mask_req_out_token(i, run_obj, mask)

-        self._overlap_req_init_and_filter(uninit_reqs=[], ok_finished_reqs=ok_finished_reqs, clear_list=True)
+                logits[mask] = -1000000.0
+
+                next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
+                next_token_ids = next_token_ids.detach().cpu().numpy()
+                next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+                self._post_handle(
+                    run_reqs,
+                    next_token_ids,
+                    next_token_logprobs,
+                    is_chuncked_mode=True,
+                    do_filter_finished_reqs=False,
+                    extra_post_req_handle_func=self._update_state_fsm,
+                )
+                logits = None
+
+        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
+        self.forward_step += 1
         return

     def _update_state_fsm(self, req_obj: InferReq, next_token_id, next_token_logprob):
@@ -138,13 +130,28 @@ def _update_state_fsm(self, req_obj: InferReq, next_token_id, next_token_logprob
     def _mask_req_out_token(self, i, run_obj: InferReq, mask):
         from outlines.fsm.guide import RegexGuide

-        sample_params = run_obj.sampling_param
-        if sample_params.regular_constraint is not None:
-            regex_guide: RegexGuide = sample_params.regex_guide
-            ok_token_id_list = regex_guide.get_next_instruction(sample_params.fsm_current_state).tokens
-            mask[i, ok_token_id_list] = False
-        elif sample_params.allowed_token_ids is not None:
-            mask[i, sample_params.allowed_token_ids] = False
+        if run_obj.get_chuncked_input_token_len() == run_obj.get_cur_total_len():
+            # this run_obj is ready to gen next token.
+            sample_params = run_obj.sampling_param
+            if sample_params.regular_constraint is not None:
+                regex_guide: RegexGuide = sample_params.regex_guide
+                ok_token_id_list = regex_guide.get_next_instruction(sample_params.fsm_current_state).tokens
+                mask[i, ok_token_id_list] = False
+            elif sample_params.allowed_token_ids is not None:
+                mask[i, sample_params.allowed_token_ids] = False
+            else:
+                mask[i, :] = False
         else:
+            # no constraint
             mask[i, :] = False
         return
+
+    def _init_guide_infos(self, run_reqs: List[InferReq]):
+        from outlines.fsm.guide import RegexGuide
+
+        for i, run_obj in enumerate(run_reqs):
+            run_obj: InferReq = run_obj
+            sample_params = run_obj.sampling_param
+            if sample_params.regular_constraint is not None:
+                if not hasattr(sample_params, "regex_guide"):
+                    sample_params.regex_guide = RegexGuide.from_regex(sample_params.regular_constraint, self.tokenizer)

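The rewritten backend follows ChunkedPrefillBackend's decode-then-prefill loop: decode requests run first, a chunked prefill pass runs when there is nothing to decode (or every max_wait_step steps, or while need_prefill_count is positive), and _mask_req_out_token only applies the regex/allow-list constraint once a request's whole prompt has been prefilled, since a mid-chunk request is not yet sampling a real output token. Below is a condensed sketch of that masking decision using plain-Python stand-ins for the model outputs and request state; all names here are illustrative, not lightllm's real InferReq or model interfaces.

NEG_INF = -1000000.0

def constraint_mask(prefilled_len, total_len, allowed_token_ids, vocab_size):
    # True means "mask this token out". Only constrain a request whose prompt is
    # fully prefilled; a request still mid chunked-prefill keeps its logits as-is.
    if prefilled_len == total_len and allowed_token_ids is not None:
        mask = [True] * vocab_size
        for tid in allowed_token_ids:
            mask[tid] = False
        return mask
    return [False] * vocab_size

def apply_mask(logits, mask):
    return [NEG_INF if m else x for x, m in zip(logits, mask)]

logits = [0.1, 0.5, 0.2, 0.9]
print(apply_mask(logits, constraint_mask(8, 8, [1, 3], 4)))  # [-1000000.0, 0.5, -1000000.0, 0.9]
print(apply_mask(logits, constraint_mask(4, 8, [1, 3], 4)))  # [0.1, 0.5, 0.2, 0.9] (still prefilling)
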