
Commit 817a24a ("fix")
1 parent: 3b6de91

2 files changed (+28 −29 lines)


lightllm/server/detokenization/decode_req.py (25 additions, 0 deletions)

@@ -1,6 +1,10 @@
 import os
 from typing import List, Dict
 from lightllm.server.core.objs import Req
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
 
 LIGHTLLM_DECODE_PREFIX_LENGTH = int(os.getenv("LIGHTLLM_DECODE_PREFIX_LENGTH", 5))
 
@@ -15,6 +19,7 @@ def __init__(
         self.group_req_id = req.group_req_id
         self.prompt_ids = req.shm_prompt_ids.arr[0 : req.input_len].tolist()
         self.output_ids = []
+        self.output_strs = []
         self.prefix_offset = max(len(self.prompt_ids) - LIGHTLLM_DECODE_PREFIX_LENGTH, 0)
 
         if is_pd_decode_mode:
@@ -26,6 +31,8 @@ def __init__(
         self.req = req
         self.input_len = self.req.input_len
         self.prefix_str = ""
+        self.stop_strs: List[str] = self.req.sample_params.stop_sequences.to_strings()
+        self.stop_str_max_len = max([len(e) for e in self.stop_strs], default=0)
 
     def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], tokenizer):
         tokens = [token_id_to_token[token_id] for token_id in self.req.prefix_token_ids.get_token_ids()]
@@ -35,6 +42,24 @@ def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], token
            self.prefix_str = ""
            return
 
+    def stop_sequences_str_match(self) -> bool:
+        stop_strs = self.stop_strs
+        if not stop_strs or self.stop_str_max_len == 0:
+            return False
+
+        tail_token_len = self.stop_str_max_len + 10  # 10 for safety
+        tail_token_strs = self.output_strs[-tail_token_len:]
+        tail_str = "".join(tail_token_strs)
+
+        for stop_str in stop_strs:
+            if stop_str in tail_str:
+                logger.info(
+                    f"req_id {self.request_id} Found stop sequence in tail: stop_str='{stop_str}', "
+                    f"tail_str='{tail_str}'"
+                )
+                return True
+        return False
+
     def need_detoken(self):
         if (
             (not self.req.is_aborted)
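
The new DecodeReq.stop_sequences_str_match scans only a bounded tail of the accumulated output strings, so the per-token cost stays constant as the generation grows. Below is a minimal standalone sketch of the same tail-window technique; the function name and test strings are illustrative, not from the repo, and the +10 slack mirrors the diff's "10 for safety":

from typing import List

def match_stop_tail(output_strs: List[str], stop_strs: List[str]) -> bool:
    # Sketch of tail-window stop matching: join only the last
    # (max_stop_len + 10) token strings, then do substring checks.
    if not stop_strs:
        return False
    max_stop_len = max(len(s) for s in stop_strs)
    if max_stop_len == 0:
        return False
    # Each list entry is one detokenized token, so max_stop_len + 10
    # entries comfortably cover a stop string that straddles token
    # boundaries.
    tail = "".join(output_strs[-(max_stop_len + 10):])
    return any(stop in tail for stop in stop_strs)

# "</s>" is split across two tokens but still matches in the joined tail.
assert match_stop_tail(["Hello", " world", "</", "s>"], ["</s>"])
assert not match_stop_tail(["Hello", " world"], ["</s>"])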

lightllm/server/detokenization/manager.py (3 additions, 29 deletions)

@@ -101,32 +101,6 @@ def handle_loop(self):
            logger.exception(f"detoken process has exception {str(e)}")
            return
 
-    def _stop_sequences_str_matched(self, decode_req, tokenizer):
-        stop_sequences_str = (
-            decode_req.req.sample_params.stop_sequences.to_string()
-            if decode_req.req.sample_params.stop_sequences
-            else []
-        )
-        if not stop_sequences_str or tokenizer is None:
-            return False
-
-        max_stop_str_len = max(len(stop_str) for stop_str in stop_sequences_str) if stop_sequences_str else 0
-        if max_stop_str_len == 0:
-            return False
-
-        output_len = len(decode_req.output_ids)
-        tail_token_len = min(decode_req.req.input_len + output_len, max_stop_str_len + 10)  # +10 for safety
-        if tail_token_len > 0:
-            tail_token_ids = decode_req.req.shm_prompt_ids.arr[
-                (decode_req.req.input_len + output_len - tail_token_len) : (decode_req.req.input_len + output_len)
-            ]
-            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
-            for stop_str in stop_sequences_str:
-                if stop_str in tail_str:
-                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', " f"tail_str='{tail_str}'")
-                    return True
-        return False
-
    def gen_token_out(self):
        exist_need_detoken = False
        exist_decode = False
@@ -161,10 +135,10 @@ def gen_token_out(self):
                    f"error token healing state, prefix_str {decode_req.prefix_str} new_text {new_text}"
                )
 
+            decode_req.output_strs.append(new_text)
+
            # stop string matching
-            if not decode_req.req.finish_status.is_stoped() and self._stop_sequences_str_matched(
-                decode_req, self.tokenizer
-            ):
+            if decode_req.stop_sequences_str_match():
                decode_req.req.stop_str_matched = True
 
            decode_req.req.out_tokens_queue.push(new_text, src_index, special, count_output_tokens)
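
Compared with the removed _stop_sequences_str_matched, the call site no longer re-decodes tail token ids from shared memory with the tokenizer on every step; it simply appends the text it just produced and asks the request object. A hedged sketch of that flow (DecodeReqStub and its inputs are illustrative stand-ins, not the real lightllm classes):

class DecodeReqStub:
    # Stand-in carrying only the two attributes this commit touches.
    def __init__(self):
        self.output_strs = []          # appended to by the detokenizer loop
        self.stop_str_matched = False  # lives on req.req in the real code

    def stop_sequences_str_match(self) -> bool:
        # Stub: the real method joins a bounded tail of output_strs and
        # checks every configured stop string as a substring.
        return "END" in "".join(self.output_strs[-8:])

decode_req = DecodeReqStub()
for new_text in ["The answer ", "is 42. EN", "D"]:
    decode_req.output_strs.append(new_text)   # mirrors gen_token_out
    if decode_req.stop_sequences_str_match():
        decode_req.stop_str_matched = True
        break

print(decode_req.stop_str_matched)  # True: "END" straddled two chunks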
