
Commit c9e31b3

Author: niushengxiao

feat: add stop string matching

1 parent 10a9b66 · commit c9e31b3

File tree

5 files changed: +90 −9 lines


lightllm/server/api_openai.py

Lines changed: 28 additions & 3 deletions
@@ -426,7 +426,9 @@ async def process_single_prompt(prompt: Union[str, List[int]], prompt_index: int
             prompt, individual_sampling_params, multimodal_params, request=raw_request
         )

-        return await _collect_generation_results(generator, request, prompt_str, prompt_index)
+        return await _collect_generation_results(
+            generator, request, prompt_str, prompt_index, individual_sampling_params
+        )

     tasks = [asyncio.create_task(process_single_prompt(prompt, i)) for i, prompt in enumerate(prompts)]

@@ -485,7 +487,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)


-async def _collect_generation_results(generator, request: CompletionRequest, prompt: str, prompt_index: int):
+async def _collect_generation_results(
+    generator, request: CompletionRequest, prompt: str, prompt_index: int, sampling_params: SamplingParams
+):
     final_output = []
     count_output_tokens = 0
     finish_reason = None
@@ -516,9 +520,30 @@ async def _collect_generation_results(generator, request: CompletionRequest, pro
             finish_reason = finish_status.get_finish_reason()
             prompt_tokens = metadata["prompt_tokens"]

+    # Strip the matched stop sequence from the output text
+    final_text = "".join(final_output)
+    if finish_reason == "stop" and sampling_params.stop_sequences.size > 0:
+        stop_strings = sampling_params.stop_sequences.to_string()
+        valid_stop_strings = [s for s in stop_strings if s]
+        if valid_stop_strings:
+            max_stop_len = len(valid_stop_strings[0])
+            search_len = min(len(final_text), max_stop_len + 20)  # search window: longest stop string length plus 20
+            tail_text = final_text[-search_len:] if search_len > 0 else final_text
+            tail_start_pos = len(final_text) - search_len
+            earliest_stop_index = len(final_text)
+            for stop_str in valid_stop_strings:
+                stop_index = tail_text.find(stop_str)
+                if stop_index != -1:
+                    actual_stop_index = tail_start_pos + stop_index
+                    if actual_stop_index < earliest_stop_index:
+                        earliest_stop_index = actual_stop_index
+
+            if earliest_stop_index < len(final_text):
+                final_text = final_text[:earliest_stop_index]
+
     return {
         "index": prompt_index,
-        "text": "".join(final_output),
+        "text": final_text,
         "finish_reason": finish_reason,
         "prompt_tokens": prompt_tokens,
         "completion_tokens": count_output_tokens,

lightllm/server/core/objs/sampling_params.py

Lines changed: 19 additions & 2 deletions
@@ -10,6 +10,7 @@

 # Maximum length limits read from environment variables
 STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256))
+STOP_SEQUENCE_STR_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH", 256))
 ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
 MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
 REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
@@ -22,17 +23,27 @@ class StopSequence(ctypes.Structure):
     _fields_ = [
         ("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
         ("size", ctypes.c_int),
+        ("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
+        ("sequence_str_len", ctypes.c_int),
     ]

-    def initialize(self, sequence: List[int]):
+    def initialize(self, sequence: List[int], sequence_str: str = ""):
         self.size = len(sequence)
         assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
         assert all(isinstance(e, int) for e in sequence), "all must be int"
         self.sequence[: self.size] = sequence[:]

+        sequence_str_bytes = sequence_str.encode("utf-8")
+        assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
+        self.sequence_str = sequence_str_bytes
+        self.sequence_str_len = len(sequence_str_bytes)
+
     def to_list(self):
         return list(self.sequence[0 : self.size])

+    def to_string(self):
+        return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")
+

 class StopSequenceGroups(ctypes.Structure):
     _pack_ = 4
@@ -45,8 +56,10 @@ def initialize(self, stop_sequences: Union[str, List], tokenizer):
         groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
         self.size = len(groups)
         assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
+        if isinstance(stop_sequences, str):
+            stop_sequences = [stop_sequences]
         for group_idx in range(self.size):
-            self.groups[group_idx].initialize(groups[group_idx])
+            self.groups[group_idx].initialize(groups[group_idx], stop_sequences[group_idx])

     def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
         if stop_sequences is None:
@@ -75,6 +88,10 @@ def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
     def to_list(self):
         return [self.groups[i].to_list() for i in range(self.size)]

+    def to_string(self):
+        # Sort by descending length so that when both "\n\n" and "\n" are present, "\n\n" is matched first
+        return sorted([self.groups[i].to_string() for i in range(self.size)], key=len, reverse=True)
+

 class RegularConstraint(ctypes.Structure):
     _pack_ = 4
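
The new sequence_str field stores the raw UTF-8 bytes of the stop string in a fixed-size ctypes buffer, with the byte length recorded separately so that multi-byte characters round-trip correctly. A minimal sketch of that mechanism (DemoStop is an illustrative stand-in, not the real StopSequence class):

import ctypes

STR_MAX = 256  # mirrors the STOP_SEQUENCE_STR_MAX_LENGTH default

class DemoStop(ctypes.Structure):
    _fields_ = [
        ("sequence_str", ctypes.c_char * STR_MAX),
        ("sequence_str_len", ctypes.c_int),
    ]

s = DemoStop()
raw = "##停止".encode("utf-8")  # multi-byte UTF-8 is why the byte length is stored explicitly
assert len(raw) < STR_MAX
s.sequence_str = raw            # ctypes zero-pads the rest of the fixed-size buffer
s.sequence_str_len = len(raw)

# Slice by the recorded byte length before decoding, as to_string() does.
assert bytes(s.sequence_str[0:s.sequence_str_len]).decode("utf-8") == "##停止"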

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 27 additions & 3 deletions
@@ -319,6 +319,7 @@ def _init_all_state(self):
         g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)

         self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
+        self.stop_sequences_str = self.sampling_param.shm_param.stop_sequences.to_string()
         # management object used only in token healing mode
         if self.shm_req.prefix_token_ids.size != 0:
             self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()
@@ -379,8 +380,10 @@ def update_mtp_accepted_token_num(self, accept_token_num: int):
     def get_last_gen_token(self):
         return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1]

-    def update_finish_status(self, eos_ids, output_len: int):
-        if self._stop_sequences_matched(output_len=output_len):
+    def update_finish_status(self, eos_ids, output_len: int, tokenizer=None):
+        if self._stop_sequences_matched(output_len=output_len) or self._stop_sequences_str_matched(
+            tokenizer, output_len
+        ):
             self.finish_status.set_status(FinishStatus.FINISHED_STOP)
         elif (
             output_len > 0
@@ -405,6 +408,26 @@ def _stop_sequences_matched(self, output_len: int):
                 return True
         return False

+    def _stop_sequences_str_matched(self, tokenizer, output_len):
+        if not self.stop_sequences_str or tokenizer is None:
+            return False
+
+        max_stop_str_len = max(len(stop_str) for stop_str in self.stop_sequences_str) if self.stop_sequences_str else 0
+        if max_stop_str_len == 0:
+            return False
+
+        tail_token_len = min(self.shm_req.input_len + output_len, max_stop_str_len + 10)  # +10 for safety
+        if tail_token_len > 0:
+            tail_token_ids = self.shm_req.shm_prompt_ids.arr[
+                (self.shm_req.input_len + output_len - tail_token_len) : (self.shm_req.input_len + output_len)
+            ]
+            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
+            for stop_str in self.stop_sequences_str:
+                if stop_str in tail_str:
+                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', tail_str='{tail_str}'")
+                    return True
+        return False
+
     def prefill_need_token_num(self, is_chuncked_prefill: bool):
         if is_chuncked_prefill:
             input_token_ids = self.get_chuncked_input_token_ids()
@@ -483,6 +506,7 @@ def handle(
         eos_ids: List[int],
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]],
         is_master_in_dp: bool,
+        tokenizer=None,
     ):
         if self.output_len <= 0:
             return
@@ -504,7 +528,7 @@
             return

         # Update the request's finished status
-        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len)
+        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len, tokenizer=tokenizer)

         if extra_post_req_handle_func is not None:
             extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)
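
The string-level check above only decodes a short tail window (the longest stop string plus a margin of 10 tokens), so a stop string that is split across several tokens is still detected without re-decoding the whole output. A minimal sketch with a toy tokenizer standing in for the real one (ToyTokenizer and stop_str_matched are illustrative names, not part of the commit):

class ToyTokenizer:
    # Stand-in for the real tokenizer: maps token ids to fixed strings.
    vocab = {0: "Hello", 1: " world", 2: "\n", 3: "\n"}

    def decode(self, token_ids, skip_special_tokens=False):
        return "".join(self.vocab[int(t)] for t in token_ids)


def stop_str_matched(token_ids, stop_strings, tokenizer, pad=10):
    # Decode only the tail: longest stop string length plus a safety margin,
    # so a stop string split across multiple tokens is still caught.
    if not stop_strings:
        return False
    max_len = max(len(s) for s in stop_strings)
    tail_len = min(len(token_ids), max_len + pad)
    tail = tokenizer.decode(token_ids[-tail_len:], skip_special_tokens=False)
    return any(s in tail for s in stop_strings)


tok = ToyTokenizer()
# "\n\n" only appears once two consecutive "\n" tokens are decoded together.
assert stop_str_matched([0, 1, 2, 3], ["\n\n"], tok) is True
assert stop_str_matched([0, 1, 2], ["\n\n"], tok) is False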

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 11 additions & 1 deletion
@@ -26,11 +26,12 @@
 from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
 from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
 from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, enable_stop_string_match
 from lightllm.distributed import dist_group_manager
 from lightllm.server.router.shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
+from lightllm.server.tokenizer import get_tokenizer


 class ModeBackend:
@@ -506,6 +507,14 @@ def _post_handle(
         extra_post_req_handle_func provides an extra post-processing step once a request's output is finalized;
         it is mainly used by constrained-output modes to set the state of the request's internal state machine
         and to add extra stop conditions.
         """
+        if enable_stop_string_match():
+            if not hasattr(self, "tokenizer"):
+                self.tokenizer = get_tokenizer(
+                    self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code
+                )
+        else:
+            self.tokenizer = None
+
         for req_obj, next_token_id, next_token_logprob, pack in zip(
             run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs
         ):
@@ -517,6 +526,7 @@
                 eos_ids=self.eos_id,
                 extra_post_req_handle_func=extra_post_req_handle_func,
                 is_master_in_dp=self.is_master_in_dp,
+                tokenizer=self.tokenizer,
             )

         g_infer_context.req_manager.req_sampling_params_manager.update_reqs_token_counter(
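
The backend loads the tokenizer lazily and only when the feature flag is on, so the default path pays no extra cost. A minimal sketch of that pattern (DemoBackend and load_tokenizer are illustrative placeholders, not the real classes):

import os


def enable_stop_string_match():
    # Same convention as the new envs_utils helper: "ON", "TRUE" or "1" enables it.
    return os.getenv("ENABLE_STOP_STRING_MATCH", "False").upper() in ["ON", "TRUE", "1"]


class DemoBackend:
    def _ensure_tokenizer(self, load_tokenizer):
        if enable_stop_string_match():
            # Load once and cache on the instance; later calls reuse it.
            if not hasattr(self, "tokenizer"):
                self.tokenizer = load_tokenizer()
        else:
            self.tokenizer = None
        return self.tokenizer


backend = DemoBackend()
os.environ["ENABLE_STOP_STRING_MATCH"] = "1"
assert backend._ensure_tokenizer(lambda: "fake-tokenizer") == "fake-tokenizer"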

lightllm/utils/envs_utils.py

Lines changed: 5 additions & 0 deletions
@@ -68,6 +68,11 @@ def get_lightllm_gunicorn_keep_alive():
     return int(os.getenv("LIGHTLMM_GUNICORN_KEEP_ALIVE", 10))


+@lru_cache(maxsize=None)
+def enable_stop_string_match():
+    return os.getenv("ENABLE_STOP_STRING_MATCH", "False").upper() in ["ON", "TRUE", "1"]
+
+
 @lru_cache(maxsize=None)
 def get_lightllm_websocket_max_message_size():
     """
