
Commit 120c833

feat: add stop string matching (#969)

1 parent 86d262a

11 files changed: +189 -47 lines changed

lightllm/server/api_openai.py

Lines changed: 18 additions & 3 deletions

@@ -426,7 +426,9 @@ async def process_single_prompt(prompt: Union[str, List[int]], prompt_index: int
             prompt, individual_sampling_params, multimodal_params, request=raw_request
         )
-        return await _collect_generation_results(generator, request, prompt_str, prompt_index)
+        return await _collect_generation_results(
+            generator, request, prompt_str, prompt_index, individual_sampling_params
+        )
 
     tasks = [asyncio.create_task(process_single_prompt(prompt, i)) for i, prompt in enumerate(prompts)]
 
@@ -485,7 +487,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
 
 
-async def _collect_generation_results(generator, request: CompletionRequest, prompt: str, prompt_index: int):
+async def _collect_generation_results(
+    generator, request: CompletionRequest, prompt: str, prompt_index: int, sampling_params: SamplingParams
+):
     final_output = []
     count_output_tokens = 0
     finish_reason = None
@@ -516,9 +520,20 @@ async def _collect_generation_results(generator, request: CompletionRequest, pro
         finish_reason = finish_status.get_finish_reason()
         prompt_tokens = metadata["prompt_tokens"]
 
+    # Strip the matched stop sequence from the tail of the output
+    final_text = "".join(final_output)
+    if finish_reason == "stop" and sampling_params.stop_sequences.size > 0:
+        valid_stop_strings = sampling_params.stop_sequences.to_strings()
+        for stop_str in valid_stop_strings:
+            stop_index = final_text.rfind(stop_str, max(0, len(final_text) - len(stop_str) - 20), len(final_text))
+            if stop_index != -1:
+                logger.debug(f"removed stop sequence in tail: '{final_text[stop_index:]}'")
+                final_text = final_text[:stop_index]
+                break
+
     return {
         "index": prompt_index,
-        "text": "".join(final_output),
+        "text": final_text,
         "finish_reason": finish_reason,
         "prompt_tokens": prompt_tokens,
         "completion_tokens": count_output_tokens,
lightllm/server/core/objs/io_objs/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd
+from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd

lightllm/server/core/objs/io_objs/group_req.py

Lines changed: 5 additions & 0 deletions

@@ -31,3 +31,8 @@ def to_group_req_index(self):
 @dataclass
 class AbortedReqCmd:
     req_id: int
+
+
+@dataclass
+class StopStrMatchedReqCmd:
+    req_id: int

lightllm/server/core/objs/req.py

Lines changed: 15 additions & 5 deletions

@@ -32,6 +32,9 @@ def get_status(self):
     def is_finished(self):
         return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH
 
+    def is_stopped(self):
+        return self.status == self.FINISHED_STOP
+
     def get_finish_reason(self):
         if self.status == self.FINISHED_STOP:
             return "stop"
@@ -74,10 +77,8 @@ class Req(ctypes.Structure):
         ("prompt_cache_len", ctypes.c_int),  # records the prompt cache hit length, for statistics
         ("is_paused", ctypes.c_bool),  # marks that a Req was temporarily paused for GPU memory management reasons
         ("finish_status", FinishStatus),
+        # Written by the http_server process and read by other processes; marks whether this request was aborted due to a dropped connection.
         ("is_aborted", ctypes.c_bool),
-        # Set by the router process after it reads is_aborted; marks the request as abort-handled and
-        # awaiting the inference process, so the router does not resend abort commands to it.
-        ("router_aborted", ctypes.c_bool),
         # When FinishStatus is a normal finished state, finish_token_index marks
         # the index of the finishing token
         ("finish_token_index", ctypes.c_int),
@@ -97,6 +98,12 @@ class Req(ctypes.Structure):
         ("mtp_accepted_token_num", ctypes.c_int),
         # _mtp_step holds a constant parameter used by mtp, kept for fast access; it is not initialized from external input
         ("_mtp_step", ctypes.c_int),
+        # stop_str_matched marks whether a stop string matched successfully. Written by the detokenization
+        # process and read by the router process, which then sends a stop command to the inference process so it stops generating.
+        ("stop_str_matched", ctypes.c_bool),
+        # When the stop_str_matched condition is met, the index of the last generated token.
+        # Written by the detokenization process and read by http_server.
+        ("stop_str_matched_token_index", ctypes.c_int),
     ]
 
     def get_str(self):
@@ -124,7 +131,6 @@ def init(
         self.is_paused = False
         self.finish_status = FinishStatus()
         self.is_aborted = False
-        self.router_aborted = False
         self.shm_infer_released = False
         self.shm_cur_kv_len = 0
         self.shm_cur_output_len = 0
@@ -150,6 +156,8 @@ def init(
         self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
         self.mtp_accepted_token_num = 0
         self._mtp_step = get_env_start_args().mtp_step
+        self.stop_str_matched = False
+        self.stop_str_matched_token_index = -1
 
         self.post_init()
 
@@ -210,7 +218,9 @@ def can_release(self):
         if self.is_aborted and can_released_mark and ref_count_ok:
             return True
 
-        if self.finish_status.is_finished() and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
+        ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched
+
+        if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
             return True
 
         return False

lightllm/server/core/objs/sampling_params.py

Lines changed: 54 additions & 28 deletions

@@ -1,6 +1,6 @@
 import os
 import ctypes
-from typing import List, Tuple, Union
+from typing import Optional, List, Tuple, Union
 from transformers import GenerationConfig
 from lightllm.server.req_id_generator import MAX_BEST_OF
 
@@ -10,6 +10,7 @@
 # Read the maximum-length limits from environment variables
 STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256))
+STOP_SEQUENCE_STR_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH", 256))
 ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
 MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
 REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
 
@@ -22,17 +23,30 @@ class StopSequence(ctypes.Structure):
     _fields_ = [
         ("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
         ("size", ctypes.c_int),
+        ("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
+        ("sequence_str_len", ctypes.c_int),
     ]
 
-    def initialize(self, sequence: List[int]):
+    def initialize(self, sequence: List[int], sequence_str: Optional[str] = None):
         self.size = len(sequence)
         assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
         assert all(isinstance(e, int) for e in sequence), "all must be int"
         self.sequence[: self.size] = sequence[:]
 
-    def to_list(self):
+        if sequence_str is not None:
+            sequence_str_bytes = sequence_str.encode("utf-8")
+            assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
+            self.sequence_str = sequence_str_bytes
+            self.sequence_str_len = len(sequence_str_bytes)
+        else:
+            self.sequence_str_len = 0
+
+    def to_list(self) -> List[int]:
         return list(self.sequence[0 : self.size])
 
+    def to_string(self) -> str:
+        return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")
+
 
 class StopSequenceGroups(ctypes.Structure):
     _pack_ = 4
@@ -41,40 +55,52 @@ class StopSequenceGroups(ctypes.Structure):
         ("size", ctypes.c_int),
     ]
 
-    def initialize(self, stop_sequences: Union[str, List], tokenizer):
+    def initialize(self, stop_sequences: Union[str, List[Union[List[int], str]]], tokenizer):
+        if stop_sequences is None:
+            stop_sequences = []
+        elif isinstance(stop_sequences, str):
+            stop_sequences = [stop_sequences]
+
         groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
         self.size = len(groups)
         assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
-        for group_idx in range(self.size):
-            self.groups[group_idx].initialize(groups[group_idx])
 
-    def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
-        if stop_sequences is None:
-            stop_sequences = []
-        else:
-            if isinstance(stop_sequences, str):
-                stop_sequences = [stop_sequences]
-
-            new_stop_sequences = []
-            for stop_info in stop_sequences:
-                if isinstance(stop_info, str):
-                    stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
-                    if stop_str_ids is not None and len(stop_str_ids) > 0:
-                        new_stop_sequences.append(stop_str_ids)
-                if isinstance(stop_info, list):
-                    if all(isinstance(x, int) for x in stop_info):
-                        if len(stop_info) > 0:
-                            new_stop_sequences.append(stop_info)
-            stop_sequences = new_stop_sequences
-        return stop_sequences
-
-    def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
+        for group_idx in range(self.size):
+            if isinstance(stop_sequences[group_idx], str):
+                self.groups[group_idx].initialize(groups[group_idx], sequence_str=stop_sequences[group_idx])
+            else:
+                self.groups[group_idx].initialize(groups[group_idx])
+
+    def stop_sentences_to_token_ids(self, stop_sequences: List[Union[List[int], str]], tokenizer) -> List[List[int]]:
+        new_stop_sequences = []
+        for stop_info in stop_sequences:
+            if isinstance(stop_info, str):
+                stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
+                if stop_str_ids is not None and len(stop_str_ids) > 0:
+                    new_stop_sequences.append(stop_str_ids)
+            if isinstance(stop_info, list):
+                if all(isinstance(x, int) for x in stop_info):
+                    if len(stop_info) > 0:
+                        new_stop_sequences.append(stop_info)
+                else:
+                    assert False, "stop_sequences item must be type List[int] when it is a list."
+        return new_stop_sequences
+
+    def _stop_str_to_token_ids(self, stop_str: str, tokenizer) -> List[int]:
         stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
         return stop_str_ids
 
-    def to_list(self):
+    def to_list(self) -> List[List[int]]:
         return [self.groups[i].to_list() for i in range(self.size)]
 
+    def to_strings(self) -> List[str]:
+        # Match in descending length order: when both "\n\n" and "\n" are present, match "\n\n" first
+        return sorted(
+            [self.groups[i].to_string() for i in range(self.size) if self.groups[i].sequence_str_len > 0],
+            key=len,
+            reverse=True,
+        )
+
 
 class RegularConstraint(ctypes.Structure):
     _pack_ = 4
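
The descending-length sort in to_strings() matters whenever one stop string is contained in another. A small illustration (values are hypothetical):

    stop_strs = sorted(["\n", "\n\n"], key=len, reverse=True)  # -> ["\n\n", "\n"]
    tail = "final answer\n\n"
    matched = next(s for s in stop_strs if s in tail)          # "\n\n" wins
    # If the shorter "\n" were tried first, the rfind-based trim in
    # api_openai.py would cut at the last newline only, leaving a stray
    # "\n" at the end of the returned text.
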

lightllm/server/detokenization/decode_req.py

Lines changed: 33 additions & 1 deletion

@@ -1,6 +1,10 @@
 import os
 from typing import List, Dict
 from lightllm.server.core.objs import Req
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
 
 
 LIGHTLLM_DECODE_PREFIX_LENGTH = int(os.getenv("LIGHTLLM_DECODE_PREFIX_LENGTH", 5))
 
@@ -15,6 +19,7 @@ def __init__(
         self.group_req_id = req.group_req_id
         self.prompt_ids = req.shm_prompt_ids.arr[0 : req.input_len].tolist()
         self.output_ids = []
+        self.output_strs = []
         self.prefix_offset = max(len(self.prompt_ids) - LIGHTLLM_DECODE_PREFIX_LENGTH, 0)
 
         if is_pd_decode_mode:
@@ -26,6 +31,9 @@ def __init__(
         self.req = req
         self.input_len = self.req.input_len
         self.prefix_str = ""
+        self.stop_strs: List[str] = self.req.sample_params.stop_sequences.to_strings()
+        # to_strings() already sorts in descending length order, so the first element is the longest string
+        self.stop_str_max_len = len(self.stop_strs[0]) if self.stop_strs else 0
 
     def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], tokenizer):
         tokens = [token_id_to_token[token_id] for token_id in self.req.prefix_token_ids.get_token_ids()]
@@ -35,8 +43,30 @@ def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], token
             self.prefix_str = ""
             return
 
+    def stop_sequences_str_match(self) -> bool:
+        stop_strs = self.stop_strs
+        if not stop_strs or self.stop_str_max_len == 0:
+            return False
+
+        tail_token_len = self.stop_str_max_len + 10  # 10 for safety
+        tail_token_strs = self.output_strs[-tail_token_len:]
+        tail_str = "".join(tail_token_strs)
+
+        for stop_str in stop_strs:
+            if stop_str in tail_str:
+                logger.debug(
+                    f"req_id {self.request_id} Found stop sequence in tail: stop_str='{stop_str}', "
+                    f"tail_str='{tail_str}'"
+                )
+                return True
+        return False
+
     def need_detoken(self):
-        if (not self.req.is_aborted) and len(self.output_ids) < self.req.candetoken_out_len:
+        if (
+            (not self.req.is_aborted)
+            and (not self.req.stop_str_matched)
+            and len(self.output_ids) < self.req.candetoken_out_len
+        ):
             return True
         return False
 
@@ -55,6 +85,8 @@ def get_decode_tokens(self):
     def can_set_release_mark(self):
         if self.req.is_aborted:
             return True
+        if self.req.stop_str_matched:
+            return True
         if (
             self.req.finish_status.is_finished()
             and self.req.candetoken_out_len == len(self.output_ids)
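
stop_sequences_str_match() deliberately matches on decoded text rather than token ids: a stop string can straddle token boundaries, so no per-token comparison would catch it. A hypothetical token stream showing this (all values illustrative):

    output_strs = ["The", " answer", " is", " 42", ".", "\n", "\n"]  # decoded per-token pieces
    stop_strs = ["\n\n"]                      # from to_strings(), longest first
    stop_str_max_len = len(stop_strs[0])      # 2
    # The tail window is counted in tokens, not characters; the "+ 10 for
    # safety" slack covers stop strings spread over short token pieces.
    tail = "".join(output_strs[-(stop_str_max_len + 10):])
    assert any(s in tail for s in stop_strs)  # "\n\n" straddles the last two tokens
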

lightllm/server/detokenization/manager.py

Lines changed: 12 additions & 0 deletions

@@ -105,6 +105,10 @@ def gen_token_out(self):
         exist_need_detoken = False
         exist_decode = False
         for decode_req in self.req_id_to_out.values():
+            # If the stop-string condition is already met, skip any further generated tokens
+            if decode_req.req.stop_str_matched:
+                continue
+
             if decode_req.need_detoken() and not decode_req.out_queue_is_full():
                 new_token_id, src_index = decode_req.get_next_token_id_and_index()
                 decode_req.output_ids.append(new_token_id)
@@ -131,6 +135,14 @@ def gen_token_out(self):
                     logger.error(
                         f"error token healing state, prefix_str {decode_req.prefix_str} new_text {new_text}"
                     )
+
+                decode_req.output_strs.append(new_text)
+
+                # Stop-string matching
+                if not decode_req.req.finish_status.is_stopped() and decode_req.stop_sequences_str_match():
+                    decode_req.req.stop_str_matched_token_index = src_index
+                    decode_req.req.stop_str_matched = True
+
                 decode_req.req.out_tokens_queue.push(new_text, src_index, special, count_output_tokens)
 
             if decode_req.need_detoken():

lightllm/server/httpserver/manager.py

Lines changed: 10 additions & 2 deletions

@@ -679,10 +679,18 @@ async def handle_loop(self):
 
                     req.out_tokens_queue.pop_no_ret()
 
-                    if req.finish_token_index != src_index:
+                    finished_token_index = (
+                        req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index
+                    )
+
+                    if finished_token_index != src_index:
                         token_list.append((req_id, text, metadata, FinishStatus()))
                     else:
-                        finish_status = FinishStatus(req.finish_status.status)
+                        if req.stop_str_matched:
+                            finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
+                        else:
+                            finish_status = FinishStatus(req.finish_status.status)
+
                         token_list.append((req_id, text, metadata, finish_status))
                 else:
                     break
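
On the http server side, the finishing token is recognized by comparing each popped token's src_index with a single recorded index; for a stop-string match that index is stop_str_matched_token_index and the finish status is forced to FINISHED_STOP. A condensed sketch of the branch (the _Req stand-in and its values are illustrative, not the real shared-memory object):

    class _Req:  # stand-in for the shared-memory Req object
        stop_str_matched = True
        stop_str_matched_token_index = 7
        finish_token_index = -1  # unused when a stop string matched

    req, src_index = _Req(), 7
    finished_token_index = (
        req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index
    )
    # The popped token's index equals the recorded one, so this is the last
    # token of the stream and it is emitted with finish reason "stop".
    assert finished_token_index == src_index
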
