
Commit 2f6e877

niushengxiao committed
feat: move stop string matching to detokenization
1 parent c9e31b3 commit 2f6e877

9 files changed: +62 -51 lines changed


lightllm/server/api_openai.py

Lines changed: 1 addition & 0 deletions

@@ -539,6 +539,7 @@ async def _collect_generation_results(
             earliest_stop_index = actual_stop_index

     if earliest_stop_index < len(final_text):
+        logger.info(f"removed stop sequence in tail: '{final_text[earliest_stop_index:]}'")
         final_text = final_text[:earliest_stop_index]

     return {

lightllm/server/core/objs/out_token_circlequeue.py

Lines changed: 8 additions & 4 deletions

@@ -14,15 +14,17 @@ class QueueItem(ctypes.Structure):
         ("special", ctypes.c_bool),
         ("count_output_tokens", ctypes.c_int),
         ("src_index", ctypes.c_int),  # index position in the source token queue
+        ("force_stop", ctypes.c_bool),  # force-stop all inference and let the client return
     ]

     def __init__(self):
         self.data_len = 0
         self.src_index = -1
         self.special = False
         self.count_output_tokens = -1
+        self.force_stop = False

-    def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int):
+    def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int, force_stop: bool):
         str_bytes = token_str.encode("utf-8")
         assert (
             len(str_bytes) <= LIGHTLLM_TOKEN_MAX_BYTES

@@ -32,6 +34,7 @@ def set(self, token_str: str, src_index: int, special: bool, count_output_tokens
         self.src_index = src_index
         self.special = special
         self.count_output_tokens = count_output_tokens
+        self.force_stop = force_stop
         return

     def get(self):

@@ -40,6 +43,7 @@ def get(self):
             self.src_index,
             self.special,
             self.count_output_tokens,
+            self.force_stop,
         )

@@ -62,13 +66,13 @@ def is_empty(self):
     def is_full(self):
         return (self.tail + 1) % LIGHTLLM_OUT_TOKEN_QUEUE_SIZE == self.head

-    def push(self, token_str: str, src_index: int, special: bool, count_output_tokens: int):
+    def push(self, token_str: str, src_index: int, special: bool, count_output_tokens: int, force_stop: bool):
         if self.is_full():
             raise Exception("Queue is full")

         # append the element
         item: QueueItem = self.items[self.tail]
-        item.set(token_str, src_index, special, count_output_tokens)
+        item.set(token_str, src_index, special, count_output_tokens, force_stop)

         # advance the tail
         self.tail = (self.tail + 1) % LIGHTLLM_OUT_TOKEN_QUEUE_SIZE

@@ -85,7 +89,7 @@ def pop(self) -> Tuple[str, int, bool, int]:
         self.head = (self.head + 1) % LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
         return result

-    def peek(self) -> Tuple[str, int, bool, int]:
+    def peek(self) -> Tuple[str, int, bool, int, bool]:
         if self.is_empty():
             raise Exception("Queue is empty")

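The queue item now carries a force_stop flag alongside the decoded text, so the detokenization process can tell the HTTP server to end a stream even when the inference backend never produced a finish token. Below is a minimal self-contained sketch of the same push/peek/pop contract; RingQueue, Item, and QUEUE_SIZE are illustrative stand-ins, not the shared-memory ctypes classes in this file.

from typing import NamedTuple

QUEUE_SIZE = 8  # stand-in for LIGHTLLM_OUT_TOKEN_QUEUE_SIZE


class Item(NamedTuple):
    token_str: str
    src_index: int
    special: bool
    count_output_tokens: int
    force_stop: bool  # set by the detokenizer when a stop string is matched


class RingQueue:
    def __init__(self):
        self.items = [None] * QUEUE_SIZE
        self.head = 0
        self.tail = 0

    def is_empty(self):
        return self.head == self.tail

    def is_full(self):
        # one slot is kept free to distinguish full from empty
        return (self.tail + 1) % QUEUE_SIZE == self.head

    def push(self, token_str, src_index, special, count_output_tokens, force_stop):
        if self.is_full():
            raise Exception("Queue is full")
        self.items[self.tail] = Item(token_str, src_index, special, count_output_tokens, force_stop)
        self.tail = (self.tail + 1) % QUEUE_SIZE

    def peek(self):
        if self.is_empty():
            raise Exception("Queue is empty")
        return self.items[self.head]

    def pop_no_ret(self):
        self.head = (self.head + 1) % QUEUE_SIZE


q = RingQueue()
q.push("Hello", 0, False, 1, False)
q.push(" wor", 1, False, 2, True)  # detokenizer matched a stop string on this token
while not q.is_empty():
    text, src_index, special, count, force_stop = q.peek()
    q.pop_no_ret()
    print(text, force_stop)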
lightllm/server/core/objs/req.py

Lines changed: 3 additions & 0 deletions

@@ -32,6 +32,9 @@ def get_status(self):
     def is_finished(self):
         return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH

+    def is_stoped(self):
+        return self.status == self.FINISHED_STOP
+
     def get_finish_reason(self):
         if self.status == self.FINISHED_STOP:
             return "stop"

lightllm/server/detokenization/decode_req.py

Lines changed: 2 additions & 1 deletion

@@ -54,7 +54,8 @@ def get_decode_tokens(self):

     def can_set_release_mark(self):
         if self.req.is_aborted:
-            return True
+            # the httpserver must finish handling the request before it can be released here
+            return self.req.out_tokens_queue.is_empty()
         if (
             self.req.finish_status.is_finished()
             and self.req.candetoken_out_len == len(self.output_ids)

lightllm/server/detokenization/manager.py

Lines changed: 39 additions & 1 deletion

@@ -101,6 +101,32 @@ def handle_loop(self):
             logger.exception(f"detoken process has exception {str(e)}")
         return

+    def _stop_sequences_str_matched(self, decode_req, tokenizer):
+        stop_sequences_str = (
+            decode_req.req.sample_params.stop_sequences.to_string()
+            if decode_req.req.sample_params.stop_sequences
+            else []
+        )
+        if not stop_sequences_str or tokenizer is None:
+            return False
+
+        max_stop_str_len = max(len(stop_str) for stop_str in stop_sequences_str) if stop_sequences_str else 0
+        if max_stop_str_len == 0:
+            return False
+
+        output_len = len(decode_req.output_ids)
+        tail_token_len = min(decode_req.req.input_len + output_len, max_stop_str_len + 10)  # +10 for safety
+        if tail_token_len > 0:
+            tail_token_ids = decode_req.req.shm_prompt_ids.arr[
+                (decode_req.req.input_len + output_len - tail_token_len) : (decode_req.req.input_len + output_len)
+            ]
+            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
+            for stop_str in stop_sequences_str:
+                if stop_str in tail_str:
+                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', " f"tail_str='{tail_str}'")
+                    return True
+        return False
+
     def gen_token_out(self):
         exist_need_detoken = False
         exist_decode = False

@@ -111,6 +137,9 @@ def gen_token_out(self):
                 special = new_token_id in self.all_special_ids
                 count_output_tokens = len(decode_req.output_ids)

+                if decode_req.req.is_aborted:
+                    continue
+
                 exist_decode = True
                 new_text = decode_token(
                     self.tokenizer,

@@ -131,7 +160,16 @@ def gen_token_out(self):
                     logger.error(
                         f"error token healing state, prefix_str {decode_req.prefix_str} new_text {new_text}"
                     )
-                decode_req.req.out_tokens_queue.push(new_text, src_index, special, count_output_tokens)
+
+                # stop string matching
+                force_stop = False
+                if not decode_req.req.finish_status.is_stoped() and self._stop_sequences_str_matched(
+                    decode_req, self.tokenizer
+                ):
+                    decode_req.req.is_aborted = True
+                    force_stop = True
+
+                decode_req.req.out_tokens_queue.push(new_text, src_index, special, count_output_tokens, force_stop)

                 if decode_req.need_detoken():
                     exist_need_detoken = True

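The new helper only decodes a short tail window rather than the whole output: the window length is derived from the longest stop string plus a margin of 10 tokens, and any stop string found inside the decoded tail forces a stop, which also catches stop strings that straddle token boundaries. A standalone sketch of that idea follows, assuming a HuggingFace-style tokenizer with a decode(ids, skip_special_tokens=...) method; tail_stop_match and its arguments are illustrative, not lightllm's API.

from typing import List, Sequence


def tail_stop_match(token_ids: Sequence[int], stop_sequences: List[str], tokenizer) -> bool:
    """Return True if any stop string appears in the decoded tail of token_ids."""
    if not stop_sequences or tokenizer is None:
        return False

    max_stop_str_len = max(len(s) for s in stop_sequences)
    if max_stop_str_len == 0:
        return False

    # Decode only the last few tokens; the window size mirrors the
    # "longest stop string + 10 for safety" margin used in this commit.
    tail_token_len = min(len(token_ids), max_stop_str_len + 10)
    if tail_token_len == 0:
        return False

    tail_str = tokenizer.decode(token_ids[-tail_token_len:], skip_special_tokens=False)
    return any(stop_str in tail_str for stop_str in stop_sequences)

Running the check in the detokenization process keeps the tokenizer out of the inference backend, which is why the ENABLE_STOP_STRING_MATCH switch and the backend-side tokenizer loading are deleted further down.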
lightllm/server/httpserver/manager.py

Lines changed: 5 additions & 3 deletions

@@ -661,7 +661,7 @@ async def handle_loop(self):
                     for _ in range(read_token_count):
                         if not req.out_tokens_queue.is_empty():

-                            text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
+                            text, src_index, special, count_output_tokens, force_stop = req.out_tokens_queue.peek()
                             req.cumlogprob += float(req.shm_logprobs.arr[src_index])
                             metadata = {
                                 "id": int(req.shm_prompt_ids.arr[src_index]),

@@ -679,10 +679,12 @@

                             req.out_tokens_queue.pop_no_ret()

-                            if req.finish_token_index != src_index:
+                            if not force_stop and req.finish_token_index != src_index:
                                 token_list.append((req_id, text, metadata, FinishStatus()))
                             else:
-                                finish_status = FinishStatus(req.finish_status.status)
+                                finish_status = FinishStatus(
+                                    req.finish_status.FINISHED_STOP if force_stop else req.finish_status.status
+                                )
                                 token_list.append((req_id, text, metadata, finish_status))
                         else:
                             break

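On the consumer side, force_stop overrides the usual finish-token check: a force-stopped token is emitted as the last one with a FINISHED_STOP status even though the backend never marked that src_index as the finish token. A small sketch of just that branch, with FinishStatusStub and the FINISHED_STOP constant as stand-in values rather than lightllm's FinishStatus:

from dataclasses import dataclass

FINISHED_STOP = 1  # stand-in constant; the real value lives in FinishStatus


@dataclass
class FinishStatusStub:
    status: int = 0  # 0 = not finished in this sketch


def classify_token(force_stop: bool, finish_token_index: int, src_index: int, req_status: int):
    """Mirror of the branch in handle_loop: decide which finish status accompanies a token."""
    if not force_stop and finish_token_index != src_index:
        # ordinary streaming token, no finish reason yet
        return FinishStatusStub()
    # final token: either the backend's own finish index, or a detokenizer-forced stop
    return FinishStatusStub(FINISHED_STOP if force_stop else req_status)


print(classify_token(force_stop=False, finish_token_index=10, src_index=3, req_status=0).status)  # 0
print(classify_token(force_stop=True, finish_token_index=10, src_index=3, req_status=0).status)   # 1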
lightllm/server/router/model_infer/infer_batch.py

Lines changed: 3 additions & 26 deletions

@@ -380,10 +380,8 @@ def update_mtp_accepted_token_num(self, accept_token_num: int):
     def get_last_gen_token(self):
         return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1]

-    def update_finish_status(self, eos_ids, output_len: int, tokenizer=None):
-        if self._stop_sequences_matched(output_len=output_len) or self._stop_sequences_str_matched(
-            tokenizer, output_len
-        ):
+    def update_finish_status(self, eos_ids, output_len: int):
+        if self._stop_sequences_matched(output_len=output_len):
             self.finish_status.set_status(FinishStatus.FINISHED_STOP)
         elif (
             output_len > 0

@@ -408,26 +406,6 @@ def _stop_sequences_matched(self, output_len: int):
                 return True
         return False

-    def _stop_sequences_str_matched(self, tokenizer, output_len):
-        if not self.stop_sequences_str or tokenizer is None:
-            return False
-
-        max_stop_str_len = max(len(stop_str) for stop_str in self.stop_sequences_str) if self.stop_sequences_str else 0
-        if max_stop_str_len == 0:
-            return False
-
-        tail_token_len = min(self.shm_req.input_len + output_len, max_stop_str_len + 10)  # +10 for safety
-        if tail_token_len > 0:
-            tail_token_ids = self.shm_req.shm_prompt_ids.arr[
-                (self.shm_req.input_len + output_len - tail_token_len) : (self.shm_req.input_len + output_len)
-            ]
-            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
-            for stop_str in self.stop_sequences_str:
-                if stop_str in tail_str:
-                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', tail_str='{tail_str}'")
-                    return True
-        return False
-
     def prefill_need_token_num(self, is_chuncked_prefill: bool):
         if is_chuncked_prefill:
             input_token_ids = self.get_chuncked_input_token_ids()

@@ -506,7 +484,6 @@ def handle(
         eos_ids: List[int],
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]],
         is_master_in_dp: bool,
-        tokenizer=None,
     ):
         if self.output_len <= 0:
             return

@@ -528,7 +505,7 @@ def handle(
             return

         # update the request's finished status
-        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len, tokenizer=tokenizer)
+        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len)

         if extra_post_req_handle_func is not None:
             extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 1 addition & 11 deletions

@@ -26,12 +26,11 @@
 from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
 from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
 from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node
-from lightllm.utils.envs_utils import get_env_start_args, enable_stop_string_match
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager
 from lightllm.server.router.shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
-from lightllm.server.tokenizer import get_tokenizer


 class ModeBackend:

@@ -507,14 +506,6 @@ def _post_handle(
         extra_post_req_handle_func provides extra post-processing hooks once a request's output is determined; it is
         mainly used by constrained-output modes to update the request's internal state machine and add extra stop conditions.
         """
-        if enable_stop_string_match():
-            if not hasattr(self, "tokenizer"):
-                self.tokenizer = get_tokenizer(
-                    self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code
-                )
-        else:
-            self.tokenizer = None
-
         for req_obj, next_token_id, next_token_logprob, pack in zip(
             run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs
         ):

@@ -526,7 +517,6 @@ def _post_handle(
                 eos_ids=self.eos_id,
                 extra_post_req_handle_func=extra_post_req_handle_func,
                 is_master_in_dp=self.is_master_in_dp,
-                tokenizer=self.tokenizer,
             )

         g_infer_context.req_manager.req_sampling_params_manager.update_reqs_token_counter(

lightllm/utils/envs_utils.py

Lines changed: 0 additions & 5 deletions

@@ -68,11 +68,6 @@ def get_lightllm_gunicorn_keep_alive():
     return int(os.getenv("LIGHTLMM_GUNICORN_KEEP_ALIVE", 10))


-@lru_cache(maxsize=None)
-def enable_stop_string_match():
-    return os.getenv("ENABLE_STOP_STRING_MATCH", "False").upper() in ["ON", "TRUE", "1"]
-
-
 @lru_cache(maxsize=None)
 def get_lightllm_websocket_max_message_size():
     """
