import os
import ctypes
- from typing import List, Tuple, Union
+ from typing import Optional, List, Tuple, Union
from transformers import GenerationConfig
from lightllm.server.req_id_generator import MAX_BEST_OF


# Get maximum length limits from environment variables
STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256))
+ STOP_SEQUENCE_STR_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH", 256))
ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
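Note: these limits are evaluated once with os.getenv at module import, so an override has to be in the environment before the process starts (or before this module is imported). A minimal sketch, with an arbitrary example value:

import os

# Hypothetical override; must be set before the sampling-params module is imported,
# because the constants above are read only once at import time.
os.environ["LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH"] = "512"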
@@ -22,17 +23,30 @@ class StopSequence(ctypes.Structure):
    _fields_ = [
        ("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
        ("size", ctypes.c_int),
+         ("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
+         ("sequence_str_len", ctypes.c_int),
    ]

-     def initialize(self, sequence: List[int]):
+     def initialize(self, sequence: List[int], sequence_str: Optional[str] = None):
        self.size = len(sequence)
        assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
        assert all(isinstance(e, int) for e in sequence), "all must be int"
        self.sequence[: self.size] = sequence[:]

-     def to_list(self):
+         if sequence_str is not None:
+             sequence_str_bytes = sequence_str.encode("utf-8")
+             assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
+             self.sequence_str = sequence_str_bytes
+             self.sequence_str_len = len(sequence_str_bytes)
+         else:
+             self.sequence_str_len = 0
+
+     def to_list(self) -> List[int]:
        return list(self.sequence[0 : self.size])

+     def to_string(self) -> str:
+         return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")
+
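For reference, a minimal sketch (not part of this change) of how the extended StopSequence round-trips both representations; the token ids below are made up and would normally come from a tokenizer:

stop = StopSequence()
stop.initialize([13, 13], sequence_str="\n\n")  # illustrative ids for "\n\n"
assert stop.to_list() == [13, 13]               # token-id view, as before
assert stop.to_string() == "\n\n"               # new: the original text survives the round trip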

class StopSequenceGroups(ctypes.Structure):
    _pack_ = 4
@@ -41,40 +55,52 @@ class StopSequenceGroups(ctypes.Structure):
        ("size", ctypes.c_int),
    ]

-     def initialize(self, stop_sequences: Union[str, List], tokenizer):
+     def initialize(self, stop_sequences: Union[str, List[Union[List[int], str]]], tokenizer):
+         if stop_sequences is None:
+             stop_sequences = []
+         elif isinstance(stop_sequences, str):
+             stop_sequences = [stop_sequences]
+
        groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
        self.size = len(groups)
        assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
-         for group_idx in range(self.size):
-             self.groups[group_idx].initialize(groups[group_idx])

-     def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
-         if stop_sequences is None:
-             stop_sequences = []
-         else:
-             if isinstance(stop_sequences, str):
-                 stop_sequences = [stop_sequences]
-
-         new_stop_sequences = []
-         for stop_info in stop_sequences:
-             if isinstance(stop_info, str):
-                 stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
-                 if stop_str_ids is not None and len(stop_str_ids) > 0:
-                     new_stop_sequences.append(stop_str_ids)
-             if isinstance(stop_info, list):
-                 if all(isinstance(x, int) for x in stop_info):
-                     if len(stop_info) > 0:
-                         new_stop_sequences.append(stop_info)
-         stop_sequences = new_stop_sequences
-         return stop_sequences
-
-     def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
+         for group_idx in range(self.size):
+             if isinstance(stop_sequences[group_idx], str):
+                 self.groups[group_idx].initialize(groups[group_idx], sequence_str=stop_sequences[group_idx])
+             else:
+                 self.groups[group_idx].initialize(groups[group_idx])
+
+     def stop_sentences_to_token_ids(self, stop_sequences: List[Union[List[int], str]], tokenizer) -> List[List[int]]:
+         new_stop_sequences = []
+         for stop_info in stop_sequences:
+             if isinstance(stop_info, str):
+                 stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
+                 if stop_str_ids is not None and len(stop_str_ids) > 0:
+                     new_stop_sequences.append(stop_str_ids)
+             if isinstance(stop_info, list):
+                 if all(isinstance(x, int) for x in stop_info):
+                     if len(stop_info) > 0:
+                         new_stop_sequences.append(stop_info)
+                 else:
+                     assert False, "stop_sequences item must be type List[int] when it is a list."
+         return new_stop_sequences
+
+     def _stop_str_to_token_ids(self, stop_str: str, tokenizer) -> List[int]:
        stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
        return stop_str_ids

-     def to_list(self):
+     def to_list(self) -> List[List[int]]:
        return [self.groups[i].to_list() for i in range(self.size)]

+     def to_strings(self) -> List[str]:
+         # Match in descending length order: when both "\n\n" and "\n" are stop strings, prefer "\n\n"
+         return sorted(
+             [self.groups[i].to_string() for i in range(self.size) if self.groups[i].sequence_str_len > 0],
+             key=len,
+             reverse=True,
+         )
+
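A small usage sketch (not part of this change) of the group-level behaviour; _ToyTokenizer is a hypothetical stand-in for a real tokenizer and only implements the encode() call used above:

class _ToyTokenizer:
    def encode(self, text, add_special_tokens=False):
        return [ord(c) for c in text]  # fake per-character token ids

groups = StopSequenceGroups()
groups.initialize(["\n\n", "\n", [2]], _ToyTokenizer())
print(groups.to_list())     # [[10, 10], [10], [2]]  (ids from the toy tokenizer)
print(groups.to_strings())  # ['\n\n', '\n']  longest string first; id-only groups have no string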

class RegularConstraint(ctypes.Structure):
    _pack_ = 4