11import os
22from typing import List , Dict
33from lightllm .server .core .objs import Req
4+ from lightllm .utils .log_utils import init_logger
5+
6+ logger = init_logger (__name__ )
7+
48
59LIGHTLLM_DECODE_PREFIX_LENGTH = int (os .getenv ("LIGHTLLM_DECODE_PREFIX_LENGTH" , 5 ))
610
@@ -15,6 +19,7 @@ def __init__(
1519 self .group_req_id = req .group_req_id
1620 self .prompt_ids = req .shm_prompt_ids .arr [0 : req .input_len ].tolist ()
1721 self .output_ids = []
22+ self .output_strs = []
1823 self .prefix_offset = max (len (self .prompt_ids ) - LIGHTLLM_DECODE_PREFIX_LENGTH , 0 )
1924
2025 if is_pd_decode_mode :
@@ -26,6 +31,8 @@ def __init__(
2631 self .req = req
2732 self .input_len = self .req .input_len
2833 self .prefix_str = ""
34+ self .stop_strs : List [str ] = self .req .sample_params .stop_sequences .to_strings ()
35+ self .stop_str_max_len = max ([len (e ) for e in self .stop_strs ])
2936
3037 def init_token_healing_prefix_str (self , token_id_to_token : Dict [int , str ], tokenizer ):
3138 tokens = [token_id_to_token [token_id ] for token_id in self .req .prefix_token_ids .get_token_ids ()]
@@ -35,6 +42,24 @@ def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], token
3542 self .prefix_str = ""
3643 return
3744
def stop_sequences_str_match(self) -> bool:
    """Return True if any configured stop string occurs in the recent output.

    Joins the trailing entries of ``self.output_strs`` and searches the
    result for each string in ``self.stop_strs``. Only a bounded tail is
    inspected, so the cost per call does not grow with the total output.

    Empty stop strings are ignored: in Python ``"" in s`` is true for every
    ``s``, so an empty entry would otherwise report a match on every call,
    even before any output exists.

    Returns:
        True when a non-empty stop string is found in the joined tail,
        False otherwise (including when no stop strings are configured).
    """
    stop_strs = self.stop_strs
    if not stop_strs or self.stop_str_max_len == 0:
        return False

    # The window is counted in output_strs entries while stop_str_max_len is
    # a character count; the extra 10 entries are a safety margin.
    # NOTE(review): presumably each entry is one decoded token with at least
    # one character -- confirm against the code that fills output_strs.
    tail_token_len = self.stop_str_max_len + 10
    tail_str = "".join(self.output_strs[-tail_token_len:])

    for stop_str in stop_strs:
        # Skip empty entries; they are a substring of everything.
        if stop_str and stop_str in tail_str:
            logger.info(
                f"req_id {self.request_id} Found stop sequence in tail: stop_str='{stop_str} ', "
                f"tail_str='{tail_str} '"
            )
            return True
    return False
62+
3863 def need_detoken (self ):
3964 if (
4065 (not self .req .is_aborted )
0 commit comments