
Commit aae3b86

niushengxiao committed
feat: move stop string matching to detokenization
1 parent c9e31b3 commit aae3b86

File tree

15 files changed, +120 -59 lines changed

lightllm/server/api_openai.py

Lines changed: 1 addition & 0 deletions
@@ -539,6 +539,7 @@ async def _collect_generation_results(
             earliest_stop_index = actual_stop_index

     if earliest_stop_index < len(final_text):
+        logger.info(f"removed stop sequence in tail: '{final_text[earliest_stop_index:]}'")
         final_text = final_text[:earliest_stop_index]

     return {
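
For orientation, a minimal self-contained sketch of the logic around this change (the function name trim_stop_sequences and its inputs are illustrative, not the repository's exact code): the earliest match across all stop strings wins, and the new log line fires just before the tail is cut.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def trim_stop_sequences(final_text: str, stop_sequences) -> str:
    # Find the earliest position at which any stop sequence begins.
    earliest_stop_index = len(final_text)
    for stop in stop_sequences:
        idx = final_text.find(stop)
        if idx != -1:
            earliest_stop_index = min(earliest_stop_index, idx)

    # Drop the stop sequence and everything after it; this is the branch
    # the commit instruments with a log line.
    if earliest_stop_index < len(final_text):
        logger.info(f"removed stop sequence in tail: '{final_text[earliest_stop_index:]}'")
        final_text = final_text[:earliest_stop_index]
    return final_text


print(trim_stop_sequences("Hello world###END trailing", ["###END", "[DONE]"]))  # Hello world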

lightllm/server/core/objs/io_objs/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd
+from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd

lightllm/server/core/objs/io_objs/group_req.py

Lines changed: 5 additions & 0 deletions
@@ -31,3 +31,8 @@ def to_group_req_index(self):
 @dataclass
 class AbortedReqCmd:
     req_id: int
+
+
+@dataclass
+class StopStrMatchedReqCmd:
+    req_id: int
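
The new command mirrors AbortedReqCmd: a plain dataclass carrying only req_id, written into the shared-memory IO buffer by the router and picked up by the inference processes. A hypothetical dispatcher illustrates the pattern (handle_cmd is not a real function in the codebase):

from dataclasses import dataclass


@dataclass
class AbortedReqCmd:
    req_id: int


@dataclass
class StopStrMatchedReqCmd:
    req_id: int


def handle_cmd(cmd) -> str:
    # Illustrative routing only; the real consumer lives in the model
    # inference process reading commands from the shared-memory buffer.
    if isinstance(cmd, AbortedReqCmd):
        return f"abort req {cmd.req_id}"
    if isinstance(cmd, StopStrMatchedReqCmd):
        return f"stop string matched, release req {cmd.req_id}"
    raise TypeError(f"unknown command: {cmd!r}")


print(handle_cmd(StopStrMatchedReqCmd(req_id=7)))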

lightllm/server/core/objs/out_token_circlequeue.py

Lines changed: 8 additions & 4 deletions
@@ -14,15 +14,17 @@ class QueueItem(ctypes.Structure):
         ("special", ctypes.c_bool),
         ("count_output_tokens", ctypes.c_int),
         ("src_index", ctypes.c_int),  # index position in the source token queue
+        ("is_stop_str_matched", ctypes.c_bool),  # force all inference to stop and let the client return
     ]

     def __init__(self):
         self.data_len = 0
         self.src_index = -1
         self.special = False
         self.count_output_tokens = -1
+        self.is_stop_str_matched = False

-    def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int):
+    def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int, is_stop_str_matched: bool):
         str_bytes = token_str.encode("utf-8")
         assert (
             len(str_bytes) <= LIGHTLLM_TOKEN_MAX_BYTES
@@ -32,6 +34,7 @@ def set(self, token_str: str, src_index: int, special: bool, count_output_tokens
         self.src_index = src_index
         self.special = special
         self.count_output_tokens = count_output_tokens
+        self.is_stop_str_matched = is_stop_str_matched
         return

     def get(self):
@@ -40,6 +43,7 @@ def get(self):
             self.src_index,
             self.special,
             self.count_output_tokens,
+            self.is_stop_str_matched,
         )


@@ -62,13 +66,13 @@ def is_empty(self):
     def is_full(self):
         return (self.tail + 1) % LIGHTLLM_OUT_TOKEN_QUEUE_SIZE == self.head

-    def push(self, token_str: str, src_index: int, special: bool, count_output_tokens: int):
+    def push(self, token_str: str, src_index: int, special: bool, count_output_tokens: int, is_stop_str_matched: bool):
         if self.is_full():
             raise Exception("Queue is full")

         # append the element
         item: QueueItem = self.items[self.tail]
-        item.set(token_str, src_index, special, count_output_tokens)
+        item.set(token_str, src_index, special, count_output_tokens, is_stop_str_matched)

         # advance the tail
         self.tail = (self.tail + 1) % LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
@@ -85,7 +89,7 @@ def pop(self) -> Tuple[str, int, bool, int]:
         self.head = (self.head + 1) % LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
         return result

-    def peek(self) -> Tuple[str, int, bool, int]:
+    def peek(self) -> Tuple[str, int, bool, int, bool]:
         if self.is_empty():
             raise Exception("Queue is empty")
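Semantics of the queue, sketched in plain Python (the real structure is a fixed-size ctypes array in shared memory; QUEUE_SIZE and the tuple layout below are illustrative): push appends at the tail, peek reads at the head without consuming, and one slot is sacrificed to distinguish full from empty.

from typing import Tuple

QUEUE_SIZE = 8  # illustrative; the real size is LIGHTLLM_OUT_TOKEN_QUEUE_SIZE


class OutTokenQueue:
    def __init__(self):
        self.items = [None] * QUEUE_SIZE
        self.head = 0
        self.tail = 0

    def is_empty(self):
        return self.head == self.tail

    def is_full(self):
        # One slot stays unused so that full and empty are distinguishable.
        return (self.tail + 1) % QUEUE_SIZE == self.head

    def push(self, token_str: str, src_index: int, special: bool,
             count_output_tokens: int, is_stop_str_matched: bool):
        if self.is_full():
            raise Exception("Queue is full")
        self.items[self.tail] = (token_str, src_index, special,
                                 count_output_tokens, is_stop_str_matched)
        self.tail = (self.tail + 1) % QUEUE_SIZE

    def peek(self) -> Tuple[str, int, bool, int, bool]:
        if self.is_empty():
            raise Exception("Queue is empty")
        return self.items[self.head]

    def pop_no_ret(self):
        self.head = (self.head + 1) % QUEUE_SIZE


q = OutTokenQueue()
q.push("Hi", 0, False, 1, False)
q.push("###", 1, False, 2, True)  # flagged: stop string matched on this token
print(q.peek())                   # ('Hi', 0, False, 1, False)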

lightllm/server/core/objs/req.py

Lines changed: 7 additions & 1 deletion
@@ -32,6 +32,9 @@ def get_status(self):
     def is_finished(self):
         return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH

+    def is_stoped(self):
+        return self.status == self.FINISHED_STOP
+
     def get_finish_reason(self):
         if self.status == self.FINISHED_STOP:
             return "stop"
@@ -97,6 +100,8 @@ class Req(ctypes.Structure):
         ("mtp_accepted_token_num", ctypes.c_int),
         # _mtp_step holds a constant parameter used by mtp for fast access; it is not initialized from external input
         ("_mtp_step", ctypes.c_int),
+        # stop_str_matched records whether a stop string has been matched successfully
+        ("stop_str_matched", ctypes.c_bool),
     ]

     def get_str(self):
@@ -150,6 +155,7 @@ def init(
         self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
         self.mtp_accepted_token_num = 0
         self._mtp_step = get_env_start_args().mtp_step
+        self.stop_str_matched = False

         self.post_init()

@@ -207,7 +213,7 @@ def can_release(self):
         ref_count_ok = self.ref_count == 1
         can_released_mark = self.can_released_mark

-        if self.is_aborted and can_released_mark and ref_count_ok:
+        if (self.is_aborted or self.stop_str_matched) and can_released_mark and ref_count_ok:
            return True

        if self.finish_status.is_finished() and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
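
The effect on release: a stop-string-matched request is now treated like an aborted one in the first branch. A boolean sketch of the condition, with illustrative flag arguments standing in for the Req fields:

def can_release(is_aborted: bool, stop_str_matched: bool, can_released_mark: bool,
                ref_count: int, is_finished: bool, out_queue_empty: bool) -> bool:
    ref_count_ok = ref_count == 1
    # Aborted or stop-string-matched requests skip the queue-empty check here;
    # for the stop-string case the detokenizer only sets can_released_mark
    # once the queue has drained (see decode_req.can_set_release_mark below).
    if (is_aborted or stop_str_matched) and can_released_mark and ref_count_ok:
        return True
    if is_finished and can_released_mark and ref_count_ok and out_queue_empty:
        return True
    return False


print(can_release(False, True, True, 1, False, False))  # True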

lightllm/server/detokenization/decode_req.py

Lines changed: 8 additions & 1 deletion
@@ -36,7 +36,11 @@ def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], token
         return

     def need_detoken(self):
-        if (not self.req.is_aborted) and len(self.output_ids) < self.req.candetoken_out_len:
+        if (
+            (not self.req.is_aborted)
+            and (not self.req.stop_str_matched)
+            and len(self.output_ids) < self.req.candetoken_out_len
+        ):
             return True
         return False

@@ -55,6 +59,9 @@ def get_decode_tokens(self):
     def can_set_release_mark(self):
         if self.req.is_aborted:
             return True
+        if self.req.stop_str_matched:
+            # the httpserver must finish handling the request before it can be released here
+            return self.req.out_tokens_queue.is_empty()
         if (
             self.req.finish_status.is_finished()
             and self.req.candetoken_out_len == len(self.output_ids)
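
The ordering constraint in the new comment, sketched as a pure function (argument names are illustrative, and the normal-finish condition is collapsed into one flag for brevity): a stop-string-matched request may only be marked releasable once the httpserver has drained its token queue, because the flagged final token still travels through that queue.

def can_set_release_mark(is_aborted: bool, stop_str_matched: bool,
                         out_queue_empty: bool, finished_and_drained: bool) -> bool:
    if is_aborted:
        return True
    if stop_str_matched:
        # The token carrying is_stop_str_matched=True sits in the
        # out_tokens_queue until the httpserver peeks and pops it;
        # releasing any earlier would drop it.
        return out_queue_empty
    return finished_and_drained


print(can_set_release_mark(False, True, False, False))  # False: queue not drained yet
print(can_set_release_mark(False, True, True, False))   # True: safe to release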

lightllm/server/detokenization/manager.py

Lines changed: 41 additions & 1 deletion
@@ -101,6 +101,32 @@ def handle_loop(self):
             logger.exception(f"detoken process has exception {str(e)}")
         return

+    def _stop_sequences_str_matched(self, decode_req, tokenizer):
+        stop_sequences_str = (
+            decode_req.req.sample_params.stop_sequences.to_string()
+            if decode_req.req.sample_params.stop_sequences
+            else []
+        )
+        if not stop_sequences_str or tokenizer is None:
+            return False
+
+        max_stop_str_len = max(len(stop_str) for stop_str in stop_sequences_str) if stop_sequences_str else 0
+        if max_stop_str_len == 0:
+            return False
+
+        output_len = len(decode_req.output_ids)
+        tail_token_len = min(decode_req.req.input_len + output_len, max_stop_str_len + 10)  # +10 for safety
+        if tail_token_len > 0:
+            tail_token_ids = decode_req.req.shm_prompt_ids.arr[
+                (decode_req.req.input_len + output_len - tail_token_len) : (decode_req.req.input_len + output_len)
+            ]
+            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
+            for stop_str in stop_sequences_str:
+                if stop_str in tail_str:
+                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', tail_str='{tail_str}'")
+                    return True
+        return False
+
     def gen_token_out(self):
         exist_need_detoken = False
         exist_decode = False
@@ -111,6 +137,9 @@ def gen_token_out(self):
             special = new_token_id in self.all_special_ids
             count_output_tokens = len(decode_req.output_ids)

+            if decode_req.req.stop_str_matched:
+                continue
+
             exist_decode = True
             new_text = decode_token(
                 self.tokenizer,
@@ -131,7 +160,18 @@ def gen_token_out(self):
                 logger.error(
                     f"error token healing state, prefix_str {decode_req.prefix_str} new_text {new_text}"
                 )
-            decode_req.req.out_tokens_queue.push(new_text, src_index, special, count_output_tokens)
+
+            # stop string matching
+            is_stop_str_matched = False
+            if not decode_req.req.finish_status.is_stoped() and self._stop_sequences_str_matched(
+                decode_req, self.tokenizer
+            ):
+                decode_req.req.stop_str_matched = True
+                is_stop_str_matched = True
+
+            decode_req.req.out_tokens_queue.push(
+                new_text, src_index, special, count_output_tokens, is_stop_str_matched
+            )

             if decode_req.need_detoken():
                 exist_need_detoken = True
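
The matching strategy decodes only a tail window of tokens, the longest stop string plus a 10-token safety margin (token and character boundaries do not line up), and substring-searches it. A self-contained sketch with a toy tokenizer (the vocabulary and names are illustrative, not the real tokenizer):

class ToyTokenizer:
    """Stand-in for the real tokenizer; maps token ids to string pieces."""
    vocab = {0: "Hello", 1: " wor", 2: "ld", 3: "##", 4: "#END", 5: "!"}

    def decode(self, ids, skip_special_tokens=False):
        return "".join(self.vocab[i] for i in ids)


def stop_sequences_str_matched(token_ids, stop_sequences_str, tokenizer):
    if not stop_sequences_str or tokenizer is None:
        return False
    max_stop_str_len = max(len(s) for s in stop_sequences_str)
    if max_stop_str_len == 0:
        return False
    # Decode only the tail: the longest stop string plus a safety margin,
    # because one character can span several tokens and vice versa.
    tail_token_len = min(len(token_ids), max_stop_str_len + 10)
    tail_str = tokenizer.decode(token_ids[-tail_token_len:], skip_special_tokens=False)
    return any(stop in tail_str for stop in stop_sequences_str)


# "###END" spans two tokens (ids 3 and 4); pure token-id matching would miss it.
print(stop_sequences_str_matched([0, 1, 2, 3, 4], ["###END"], ToyTokenizer()))  # True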

lightllm/server/httpserver/manager.py

Lines changed: 13 additions & 3 deletions
@@ -661,7 +661,13 @@ async def handle_loop(self):
                 for _ in range(read_token_count):
                     if not req.out_tokens_queue.is_empty():

-                        text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
+                        (
+                            text,
+                            src_index,
+                            special,
+                            count_output_tokens,
+                            is_stop_str_matched,
+                        ) = req.out_tokens_queue.peek()
                         req.cumlogprob += float(req.shm_logprobs.arr[src_index])
                         metadata = {
                             "id": int(req.shm_prompt_ids.arr[src_index]),
@@ -679,10 +685,14 @@ async def handle_loop(self):

                         req.out_tokens_queue.pop_no_ret()

-                        if req.finish_token_index != src_index:
+                        if not is_stop_str_matched and req.finish_token_index != src_index:
                             token_list.append((req_id, text, metadata, FinishStatus()))
                         else:
-                            finish_status = FinishStatus(req.finish_status.status)
+                            finish_status = FinishStatus(
+                                req.finish_status.FINISHED_STOP
+                                if is_stop_str_matched
+                                else req.finish_status.status
+                            )
                             token_list.append((req_id, text, metadata, finish_status))
                     else:
                         break
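
On the consuming side, a token flagged is_stop_str_matched is treated as the request's final token and reported with finish reason "stop", even though the inference process never set the shared finish_status itself. A condensed sketch of that branch (names are illustrative):

FINISHED_STOP = "stop"  # illustrative; the real code uses FinishStatus constants


def classify_token(is_stop_str_matched: bool, finish_token_index: int,
                   src_index: int, real_finish_status: str):
    if not is_stop_str_matched and finish_token_index != src_index:
        return None  # ordinary streaming token, no finish status yet
    # Either this is the naturally final token, or the detokenizer flagged
    # a stop-string match: force "stop" in the latter case.
    return FINISHED_STOP if is_stop_str_matched else real_finish_status


print(classify_token(True, 99, 5, "length"))   # 'stop' despite the stored status
print(classify_token(False, 5, 5, "length"))   # 'length'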

lightllm/server/router/manager.py

Lines changed: 22 additions & 1 deletion
@@ -15,7 +15,7 @@
 from .batch import Batch, Req
 from .model_infer.model_rpc import start_model_process, ModelRpcClient
 from .req_queue import build_req_queue
-from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd
+from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd, StopStrMatchedReqCmd
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
 from .shm_reqs_io_buffer import ShmReqsIOBuffer
@@ -277,8 +277,11 @@ async def _step(self):

         self._filter_reqs_from_running_batch()
         aborted_reqs = self._get_aborted_reqs_from_running_batch()
+        stop_str_matched_reqs = self._get_stop_str_reqs_from_running_batch()
         if aborted_reqs:
             await self._aborted_reqs(aborted_reqs=aborted_reqs)
+        if stop_str_matched_reqs:
+            await self._stop_str_matched_reqs(stop_str_matched_reqs=stop_str_matched_reqs)
         return

     async def _add_batch(self, batch: Batch):
@@ -301,6 +304,15 @@ async def _aborted_reqs(self, aborted_reqs: List[Req]):
         self.shm_reqs_io_buffer.set_ready()
         return

+    async def _stop_str_matched_reqs(self, stop_str_matched_reqs: List[Req]):
+        cmds = [StopStrMatchedReqCmd(req_id=r.request_id) for r in stop_str_matched_reqs]
+        while not self.shm_reqs_io_buffer.is_empty():
+            await asyncio.sleep(0.02)
+
+        self.shm_reqs_io_buffer.write_obj(cmds)
+        self.shm_reqs_io_buffer.set_ready()
+        return
+
     def _add_new_batch_to_running_batch(self, new_batch: Batch):
         if self.running_batch is None:
             self.running_batch = new_batch
@@ -325,6 +337,15 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]:
             ans.append(req)
         return ans

+    def _get_stop_str_reqs_from_running_batch(self) -> List[Req]:
+        ans = []
+        if self.running_batch is None:
+            return ans
+        for req in self.running_batch.reqs:
+            if req.stop_str_matched:
+                ans.append(req)
+        return ans
+
     def _get_paused_req_num(self) -> int:
         if self.running_batch is None:
             return 0
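
The broadcast path mirrors _aborted_reqs: poll until the shared-memory IO buffer has drained, write the command list, then mark it ready. A hedged asyncio sketch with a stubbed buffer (FakeShmReqsIOBuffer is illustrative, not the real ShmReqsIOBuffer):

import asyncio
from dataclasses import dataclass


@dataclass
class StopStrMatchedReqCmd:
    req_id: int


class FakeShmReqsIOBuffer:
    """Stand-in for ShmReqsIOBuffer: a single-slot mailbox."""
    def __init__(self):
        self.obj = None
        self.ready = False

    def is_empty(self):
        return self.obj is None

    def write_obj(self, obj):
        self.obj = obj

    def set_ready(self):
        self.ready = True


async def send_stop_str_matched_cmds(buf, req_ids):
    cmds = [StopStrMatchedReqCmd(req_id=r) for r in req_ids]
    while not buf.is_empty():  # wait for the previous payload to drain
        await asyncio.sleep(0.02)
    buf.write_obj(cmds)
    buf.set_ready()


buf = FakeShmReqsIOBuffer()
asyncio.run(send_stop_str_matched_cmds(buf, [3, 7]))
print(buf.obj)  # [StopStrMatchedReqCmd(req_id=3), StopStrMatchedReqCmd(req_id=7)]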

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 3 additions & 27 deletions
@@ -319,7 +319,6 @@ def _init_all_state(self):
         g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)

         self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
-        self.stop_sequences_str = self.sampling_param.shm_param.stop_sequences.to_string()
         # management object used only in token healing mode
         if self.shm_req.prefix_token_ids.size != 0:
             self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()
@@ -380,10 +379,8 @@ def update_mtp_accepted_token_num(self, accept_token_num: int):
     def get_last_gen_token(self):
         return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1]

-    def update_finish_status(self, eos_ids, output_len: int, tokenizer=None):
-        if self._stop_sequences_matched(output_len=output_len) or self._stop_sequences_str_matched(
-            tokenizer, output_len
-        ):
+    def update_finish_status(self, eos_ids, output_len: int):
+        if self._stop_sequences_matched(output_len=output_len):
             self.finish_status.set_status(FinishStatus.FINISHED_STOP)
         elif (
             output_len > 0
@@ -408,26 +405,6 @@ def _stop_sequences_matched(self, output_len: int):
                 return True
         return False

-    def _stop_sequences_str_matched(self, tokenizer, output_len):
-        if not self.stop_sequences_str or tokenizer is None:
-            return False
-
-        max_stop_str_len = max(len(stop_str) for stop_str in self.stop_sequences_str) if self.stop_sequences_str else 0
-        if max_stop_str_len == 0:
-            return False
-
-        tail_token_len = min(self.shm_req.input_len + output_len, max_stop_str_len + 10)  # +10 for safety
-        if tail_token_len > 0:
-            tail_token_ids = self.shm_req.shm_prompt_ids.arr[
-                (self.shm_req.input_len + output_len - tail_token_len) : (self.shm_req.input_len + output_len)
-            ]
-            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
-            for stop_str in self.stop_sequences_str:
-                if stop_str in tail_str:
-                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', tail_str='{tail_str}'")
-                    return True
-        return False
-
     def prefill_need_token_num(self, is_chuncked_prefill: bool):
         if is_chuncked_prefill:
             input_token_ids = self.get_chuncked_input_token_ids()
@@ -506,7 +483,6 @@ def handle(
         eos_ids: List[int],
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]],
         is_master_in_dp: bool,
-        tokenizer=None,
     ):
         if self.output_len <= 0:
             return
@@ -528,7 +504,7 @@ def handle(
             return

         # update the request's finished status
-        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len, tokenizer=tokenizer)
+        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len)

         if extra_post_req_handle_func is not None:
             extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)
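
What remains in the inference worker is exact token-id suffix matching, which needs no tokenizer and stays cheap in the per-step path; string-level stop sequences that span several tokens are now caught in the detokenization process instead. An illustrative sketch of the surviving check (not the repository's exact code):

def stop_sequences_matched(output_ids, stop_sequences):
    # Token-id suffix matching: no tokenizer required, so it can run in the
    # inference worker on every decoding step.
    for stop_ids in stop_sequences:
        n = len(stop_ids)
        if n and len(output_ids) >= n and output_ids[-n:] == stop_ids:
            return True
    return False


print(stop_sequences_matched([11, 42, 7, 9], [[7, 9]]))  # True
print(stop_sequences_matched([11, 42, 7, 9], [[99]]))    # False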
