
Commit fcc8540

Author: niushengxiao
Parent: c9e31b3

feat: move stop string matching to detokenization

14 files changed: +106 -54 lines

lightllm/server/api_openai.py
1 addition, 0 deletions

@@ -539,6 +539,7 @@ async def _collect_generation_results(
             earliest_stop_index = actual_stop_index

     if earliest_stop_index < len(final_text):
+        logger.info(f"removed stop sequence in tail: '{final_text[earliest_stop_index:]}'")
         final_text = final_text[:earliest_stop_index]

     return {
lightllm/server/core/objs/io_objs/__init__.py
1 addition, 1 deletion

@@ -1 +1 @@
-from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd
+from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd

lightllm/server/core/objs/io_objs/group_req.py
5 additions, 0 deletions

@@ -31,3 +31,8 @@ def to_group_req_index(self):
 @dataclass
 class AbortedReqCmd:
     req_id: int
+
+
+@dataclass
+class StopStrMatchedReqCmd:
+    req_id: int

lightllm/server/core/objs/req.py
7 additions, 1 deletion

@@ -32,6 +32,9 @@ def get_status(self):
     def is_finished(self):
         return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH

+    def is_stoped(self):
+        return self.status == self.FINISHED_STOP
+
     def get_finish_reason(self):
         if self.status == self.FINISHED_STOP:
             return "stop"

@@ -97,6 +100,8 @@ class Req(ctypes.Structure):
         ("mtp_accepted_token_num", ctypes.c_int),
         # mtp_step stores a constant used by mtp for fast access; it is not initialized from external input
         ("_mtp_step", ctypes.c_int),
+        # stop_str_matched marks whether a stop string has been matched
+        ("stop_str_matched", ctypes.c_bool),
     ]

     def get_str(self):

@@ -150,6 +155,7 @@ def init(
         self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
         self.mtp_accepted_token_num = 0
         self._mtp_step = get_env_start_args().mtp_step
+        self.stop_str_matched = False

         self.post_init()

@@ -207,7 +213,7 @@ def can_release(self):
         ref_count_ok = self.ref_count == 1
         can_released_mark = self.can_released_mark

-        if self.is_aborted and can_released_mark and ref_count_ok:
+        if (self.is_aborted or self.stop_str_matched) and can_released_mark and ref_count_ok:
             return True

         if self.finish_status.is_finished() and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():

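With the new shared-memory flag, can_release treats a stop-string match the same way as an abort: the request slot may be reclaimed once the release mark is set and only one reference remains. A minimal sketch of that fast-path condition, assuming simplified stand-in fields (ReqState and can_release_early are hypothetical; the real Req is a ctypes structure in shared memory):

from dataclasses import dataclass


@dataclass
class ReqState:
    # Simplified stand-ins for the shared-memory fields consulted by Req.can_release().
    is_aborted: bool = False
    stop_str_matched: bool = False
    can_released_mark: bool = False
    ref_count: int = 1


def can_release_early(req: ReqState) -> bool:
    # Mirrors the changed fast path: a stop-string-matched request is released
    # the same way an aborted one is, once the release mark is set and the
    # reference count has dropped back to one.
    ref_count_ok = req.ref_count == 1
    return (req.is_aborted or req.stop_str_matched) and req.can_released_mark and ref_count_ok


print(can_release_early(ReqState(stop_str_matched=True, can_released_mark=True)))   # True
print(can_release_early(ReqState(stop_str_matched=True, can_released_mark=False)))  # False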
lightllm/server/detokenization/decode_req.py
8 additions, 1 deletion

@@ -36,7 +36,11 @@ def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], token
         return

     def need_detoken(self):
-        if (not self.req.is_aborted) and len(self.output_ids) < self.req.candetoken_out_len:
+        if (
+            (not self.req.is_aborted)
+            and (not self.req.stop_str_matched)
+            and len(self.output_ids) < self.req.candetoken_out_len
+        ):
             return True
         return False

@@ -55,6 +59,9 @@ def get_decode_tokens(self):
     def can_set_release_mark(self):
         if self.req.is_aborted:
             return True
+        if self.req.stop_str_matched:
+            # the httpserver must finish handling the request before it can be released here
+            return self.req.out_tokens_queue.is_empty()
         if (
             self.req.finish_status.is_finished()
             and self.req.candetoken_out_len == len(self.output_ids)

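The detokenizer-side gating has one ordering subtlety: a stop-string-matched request may still have text sitting in out_tokens_queue that the httpserver has not streamed yet, so the release mark is deferred until that queue drains. A toy illustration of the same gating, with FakeShmReq and a deque as hypothetical stand-ins for the real shared-memory request and token queue:

from collections import deque


class FakeShmReq:
    # Minimal stand-in for the shared-memory Req consulted by DecodeReq.can_set_release_mark().
    def __init__(self):
        self.is_aborted = False
        self.stop_str_matched = False
        self.out_tokens_queue = deque()


def can_set_release_mark(req: FakeShmReq) -> bool:
    if req.is_aborted:
        return True
    if req.stop_str_matched:
        # the httpserver must drain the remaining streamed text before release
        return len(req.out_tokens_queue) == 0
    return False  # the normal finished-request path is omitted in this sketch


req = FakeShmReq()
req.stop_str_matched = True
req.out_tokens_queue.append("pending tail text")
print(can_set_release_mark(req))  # False: text is still waiting to be streamed
req.out_tokens_queue.popleft()
print(can_set_release_mark(req))  # True: safe to mark for release now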
lightllm/server/detokenization/manager.py
36 additions, 0 deletions

@@ -101,6 +101,32 @@ def handle_loop(self):
             logger.exception(f"detoken process has exception {str(e)}")
             return

+    def _stop_sequences_str_matched(self, decode_req, tokenizer):
+        stop_sequences_str = (
+            decode_req.req.sample_params.stop_sequences.to_string()
+            if decode_req.req.sample_params.stop_sequences
+            else []
+        )
+        if not stop_sequences_str or tokenizer is None:
+            return False
+
+        max_stop_str_len = max(len(stop_str) for stop_str in stop_sequences_str) if stop_sequences_str else 0
+        if max_stop_str_len == 0:
+            return False
+
+        output_len = len(decode_req.output_ids)
+        tail_token_len = min(decode_req.req.input_len + output_len, max_stop_str_len + 10)  # +10 for safety
+        if tail_token_len > 0:
+            tail_token_ids = decode_req.req.shm_prompt_ids.arr[
+                (decode_req.req.input_len + output_len - tail_token_len) : (decode_req.req.input_len + output_len)
+            ]
+            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
+            for stop_str in stop_sequences_str:
+                if stop_str in tail_str:
+                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', " f"tail_str='{tail_str}'")
+                    return True
+        return False
+
     def gen_token_out(self):
         exist_need_detoken = False
         exist_decode = False

@@ -111,6 +137,9 @@ def gen_token_out(self):
             special = new_token_id in self.all_special_ids
             count_output_tokens = len(decode_req.output_ids)

+            if decode_req.req.stop_str_matched:
+                continue
+
             exist_decode = True
             new_text = decode_token(
                 self.tokenizer,

@@ -131,6 +160,13 @@ def gen_token_out(self):
                     logger.error(
                         f"error token healing state, prefix_str {decode_req.prefix_str} new_text {new_text}"
                     )
+
+            # stop string matching
+            if not decode_req.req.finish_status.is_stoped() and self._stop_sequences_str_matched(
+                decode_req, self.tokenizer
+            ):
+                decode_req.req.stop_str_matched = True
+
             decode_req.req.out_tokens_queue.push(new_text, src_index, special, count_output_tokens)

             if decode_req.need_detoken():

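The matching logic itself is unchanged, only relocated from the inference worker into the detokenization process: it decodes just the last max_stop_str_len + 10 tokens and looks for any stop string inside that tail, which also catches stop strings that straddle token boundaries. Below is a self-contained sketch of the same bounded-tail idea; CharTokenizer and stop_str_matched are hypothetical stand-ins so the example runs without a real tokenizer, and the "+ 10" simply mirrors the commit's own "+10 for safety" margin:

from typing import List, Sequence


class CharTokenizer:
    """Hypothetical stand-in: one token id per character, just to make the sketch runnable."""

    def decode(self, token_ids: Sequence[int], skip_special_tokens: bool = False) -> str:
        return "".join(chr(t) for t in token_ids)


def stop_str_matched(all_token_ids: List[int], stop_strs: List[str], tokenizer) -> bool:
    if not stop_strs:
        return False
    max_stop_str_len = max(len(s) for s in stop_strs)
    if max_stop_str_len == 0:
        return False
    # Decode only a bounded tail of the sequence; the "+ 10" mirrors the safety
    # margin used in the commit ("# +10 for safety").
    tail_len = min(len(all_token_ids), max_stop_str_len + 10)
    tail_str = tokenizer.decode(all_token_ids[len(all_token_ids) - tail_len:], skip_special_tokens=False)
    return any(s in tail_str for s in stop_strs)


tokens = [ord(c) for c in "The answer is 42.\nObservation:"]
print(stop_str_matched(tokens, ["Observation:"], CharTokenizer()))  # True
print(stop_str_matched(tokens, ["</answer>"], CharTokenizer()))     # False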
lightllm/server/httpserver/manager.py
12 additions, 3 deletions

@@ -661,7 +661,12 @@ async def handle_loop(self):
                 for _ in range(read_token_count):
                     if not req.out_tokens_queue.is_empty():

-                        text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
+                        (
+                            text,
+                            src_index,
+                            special,
+                            count_output_tokens,
+                        ) = req.out_tokens_queue.peek()
                         req.cumlogprob += float(req.shm_logprobs.arr[src_index])
                         metadata = {
                             "id": int(req.shm_prompt_ids.arr[src_index]),

@@ -679,10 +684,14 @@ async def handle_loop(self):

                         req.out_tokens_queue.pop_no_ret()

-                        if req.finish_token_index != src_index:
+                        if not req.stop_str_matched and req.finish_token_index != src_index:
                             token_list.append((req_id, text, metadata, FinishStatus()))
                         else:
-                            finish_status = FinishStatus(req.finish_status.status)
+                            finish_status = FinishStatus(
+                                req.finish_status.FINISHED_STOP
+                                if req.stop_str_matched
+                                else req.finish_status.status
+                            )
                             token_list.append((req_id, text, metadata, finish_status))
                     else:
                         break

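On the streaming side, a request whose output was cut by the detokenizer must still be reported to the client with finish reason "stop", even though the inference worker never set FINISHED_STOP itself. A toy sketch of that decision per drained token (the integer constants and status_for_streamed_token are illustrative, not the real FinishStatus API):

# Illustrative constants; the real FinishStatus is a ctypes-backed shared-memory object.
NO_FINISH = 0
FINISHED_STOP = 1
FINISHED_LENGTH = 2


def status_for_streamed_token(stop_str_matched: bool, is_finish_token: bool, worker_status: int) -> int:
    # Mirrors the changed branch: a stop-string-matched request is finalized on the
    # next token the httpserver drains, instead of waiting for finish_token_index,
    # and it is reported as FINISHED_STOP regardless of the worker's own status.
    if not stop_str_matched and not is_finish_token:
        return NO_FINISH
    return FINISHED_STOP if stop_str_matched else worker_status


print(status_for_streamed_token(False, False, NO_FINISH))       # 0: ordinary streamed token
print(status_for_streamed_token(False, True, FINISHED_LENGTH))  # 2: normal finish, reason "length"
print(status_for_streamed_token(True, False, NO_FINISH))        # 1: detokenizer matched a stop string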
lightllm/server/router/manager.py
22 additions, 1 deletion

@@ -15,7 +15,7 @@
 from .batch import Batch, Req
 from .model_infer.model_rpc import start_model_process, ModelRpcClient
 from .req_queue import build_req_queue
-from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd
+from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd, StopStrMatchedReqCmd
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
 from .shm_reqs_io_buffer import ShmReqsIOBuffer

@@ -277,8 +277,11 @@ async def _step(self):

         self._filter_reqs_from_running_batch()
         aborted_reqs = self._get_aborted_reqs_from_running_batch()
+        stop_str_matched_reqs = self._get_stop_str_reqs_from_running_batch()
         if aborted_reqs:
             await self._aborted_reqs(aborted_reqs=aborted_reqs)
+        if stop_str_matched_reqs:
+            await self._stop_str_matched_reqs(stop_str_matched_reqs=stop_str_matched_reqs)
         return

     async def _add_batch(self, batch: Batch):

@@ -301,6 +304,15 @@ async def _aborted_reqs(self, aborted_reqs: List[Req]):
         self.shm_reqs_io_buffer.set_ready()
         return

+    async def _stop_str_matched_reqs(self, stop_str_matched_reqs: List[Req]):
+        cmds = [StopStrMatchedReqCmd(req_id=r.request_id) for r in stop_str_matched_reqs]
+        while not self.shm_reqs_io_buffer.is_empty():
+            await asyncio.sleep(0.02)
+
+        self.shm_reqs_io_buffer.write_obj(cmds)
+        self.shm_reqs_io_buffer.set_ready()
+        return
+
     def _add_new_batch_to_running_batch(self, new_batch: Batch):
         if self.running_batch is None:
             self.running_batch = new_batch

@@ -325,6 +337,15 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]:
             ans.append(req)
         return ans

+    def _get_stop_str_reqs_from_running_batch(self) -> List[Req]:
+        ans = []
+        if self.running_batch is None:
+            return ans
+        for req in self.running_batch.reqs:
+            if req.stop_str_matched:
+                ans.append(req)
+        return ans
+
     def _get_paused_req_num(self) -> int:
         if self.running_batch is None:
             return 0

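The router's part is only to relay the flag: on each scheduling step it collects running requests whose stop_str_matched bit is set and writes a batch of StopStrMatchedReqCmd into the shared command buffer, mirroring the existing abort path. A minimal asyncio sketch of that relay, with FakeReq and FakeCmdBuffer as hypothetical in-memory stand-ins for the real request objects and ShmReqsIOBuffer:

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StopStrMatchedReqCmd:
    req_id: int


@dataclass
class FakeReq:
    request_id: int
    stop_str_matched: bool = False


class FakeCmdBuffer:
    """In-memory stand-in for the shared-memory command buffer."""

    def __init__(self):
        self._obj: Optional[list] = None

    def is_empty(self) -> bool:
        return self._obj is None

    def write_obj(self, obj: list) -> None:
        self._obj = obj

    def set_ready(self) -> None:
        print("ready:", self._obj)


async def relay_stop_str_matched(running_reqs: List[FakeReq], buf: FakeCmdBuffer) -> None:
    matched = [r for r in running_reqs if r.stop_str_matched]
    if not matched:
        return
    cmds = [StopStrMatchedReqCmd(req_id=r.request_id) for r in matched]
    while not buf.is_empty():  # wait until the previous command batch has been consumed
        await asyncio.sleep(0.02)
    buf.write_obj(cmds)
    buf.set_ready()


asyncio.run(relay_stop_str_matched([FakeReq(1), FakeReq(2, stop_str_matched=True)], FakeCmdBuffer()))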
lightllm/server/router/model_infer/infer_batch.py
3 additions, 27 deletions

@@ -319,7 +319,6 @@ def _init_all_state(self):
         g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)

         self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
-        self.stop_sequences_str = self.sampling_param.shm_param.stop_sequences.to_string()
         # management object used only in token healing mode
         if self.shm_req.prefix_token_ids.size != 0:
             self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()

@@ -380,10 +379,8 @@ def update_mtp_accepted_token_num(self, accept_token_num: int):
     def get_last_gen_token(self):
         return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1]

-    def update_finish_status(self, eos_ids, output_len: int, tokenizer=None):
-        if self._stop_sequences_matched(output_len=output_len) or self._stop_sequences_str_matched(
-            tokenizer, output_len
-        ):
+    def update_finish_status(self, eos_ids, output_len: int):
+        if self._stop_sequences_matched(output_len=output_len):
             self.finish_status.set_status(FinishStatus.FINISHED_STOP)
         elif (
             output_len > 0

@@ -408,26 +405,6 @@ def _stop_sequences_matched(self, output_len: int):
             return True
         return False

-    def _stop_sequences_str_matched(self, tokenizer, output_len):
-        if not self.stop_sequences_str or tokenizer is None:
-            return False
-
-        max_stop_str_len = max(len(stop_str) for stop_str in self.stop_sequences_str) if self.stop_sequences_str else 0
-        if max_stop_str_len == 0:
-            return False
-
-        tail_token_len = min(self.shm_req.input_len + output_len, max_stop_str_len + 10)  # +10 for safety
-        if tail_token_len > 0:
-            tail_token_ids = self.shm_req.shm_prompt_ids.arr[
-                (self.shm_req.input_len + output_len - tail_token_len) : (self.shm_req.input_len + output_len)
-            ]
-            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
-            for stop_str in self.stop_sequences_str:
-                if stop_str in tail_str:
-                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', tail_str='{tail_str}'")
-                    return True
-        return False
-
     def prefill_need_token_num(self, is_chuncked_prefill: bool):
         if is_chuncked_prefill:
             input_token_ids = self.get_chuncked_input_token_ids()

@@ -506,7 +483,6 @@ def handle(
         eos_ids: List[int],
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]],
         is_master_in_dp: bool,
-        tokenizer=None,
     ):
         if self.output_len <= 0:
             return

@@ -528,7 +504,7 @@ def handle(
             return

         # update the request's finished status
-        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len, tokenizer=tokenizer)
+        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len)

         if extra_post_req_handle_func is not None:
             extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)

lightllm/server/router/model_infer/mode_backend/base_backend.py
8 additions, 12 deletions

@@ -19,19 +19,18 @@
 from lightllm.utils.dist_utils import init_distributed_env
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.server.core.objs import ShmReqManager, StartArgs
-from lightllm.server.core.objs.io_objs import AbortedReqCmd
+from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd
 from lightllm.server.router.model_infer.infer_batch import g_infer_context
 from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
 from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size
 from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
 from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
 from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node
-from lightllm.utils.envs_utils import get_env_start_args, enable_stop_string_match
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager
 from lightllm.server.router.shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
-from lightllm.server.tokenizer import get_tokenizer


 class ModeBackend:

@@ -322,6 +321,12 @@ def _read_reqs_buffer_and_init_reqs(self):
                     if obj.req_id in g_infer_context.requests_mapping:
                         req: InferReq = g_infer_context.requests_mapping[obj.req_id]
                         req.infer_aborted = True
+            elif isinstance(cmds[0], StopStrMatchedReqCmd):
+                for obj in cmds:
+                    obj: StopStrMatchedReqCmd = obj
+                    if obj.req_id in g_infer_context.requests_mapping:
+                        req: InferReq = g_infer_context.requests_mapping[obj.req_id]
+                        req.infer_aborted = True
             else:
                 self._init_reqs(reqs=cmds)
             return

@@ -507,14 +512,6 @@ def _post_handle(
         extra_post_req_handle_func provides extra post-processing once a request's output is determined; it is mainly
         used by constrained-output modes to update the request's internal state machine and add extra stop conditions.
         """
-        if enable_stop_string_match():
-            if not hasattr(self, "tokenizer"):
-                self.tokenizer = get_tokenizer(
-                    self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code
-                )
-        else:
-            self.tokenizer = None
-
         for req_obj, next_token_id, next_token_logprob, pack in zip(
             run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs
         ):

@@ -526,7 +523,6 @@ def _post_handle(
                 eos_ids=self.eos_id,
                 extra_post_req_handle_func=extra_post_req_handle_func,
                 is_master_in_dp=self.is_master_in_dp,
-                tokenizer=self.tokenizer,
             )

         g_infer_context.req_manager.req_sampling_params_manager.update_reqs_token_counter(

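On the inference backend, the new command is handled exactly like an abort: the request's infer_aborted flag is set so it drops out of scheduling, while the tail already pushed to the output queue is presumably left for the httpserver to finish streaming. A small sketch of the dispatch-by-command-type pattern (the mapping and FakeInferReq are simplified stand-ins for g_infer_context.requests_mapping and InferReq):

from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class AbortedReqCmd:
    req_id: int


@dataclass
class StopStrMatchedReqCmd:
    req_id: int


class FakeInferReq:
    def __init__(self):
        self.infer_aborted = False


def apply_cmds(cmds: List[Union[AbortedReqCmd, StopStrMatchedReqCmd]],
               requests_mapping: Dict[int, FakeInferReq]) -> None:
    # A batch holds commands of a single type, so inspecting cmds[0] is enough.
    if cmds and isinstance(cmds[0], (AbortedReqCmd, StopStrMatchedReqCmd)):
        for cmd in cmds:
            req = requests_mapping.get(cmd.req_id)
            if req is not None:
                # both command types make the backend stop scheduling the request
                req.infer_aborted = True


mapping = {5: FakeInferReq()}
apply_cmds([StopStrMatchedReqCmd(req_id=5)], mapping)
print(mapping[5].infer_aborted)  # True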