import os
import ctypes
-from typing import List, Tuple, Union
+from typing import Optional, List, Tuple, Union
from transformers import GenerationConfig
from lightllm.server.req_id_generator import MAX_BEST_OF

@@ -27,21 +27,24 @@ class StopSequence(ctypes.Structure):
2727 ("sequence_str_len" , ctypes .c_int ),
2828 ]
2929
30- def initialize (self , sequence : List [int ], sequence_str : str = "" ):
30+ def initialize (self , sequence : List [int ], sequence_str : Optional [ str ] = None ):
3131 self .size = len (sequence )
3232 assert self .size <= STOP_SEQUENCE_MAX_LENGTH , "stop token length too long."
3333 assert all (isinstance (e , int ) for e in sequence ), "all must be int"
3434 self .sequence [: self .size ] = sequence [:]
3535
36- sequence_str_bytes = sequence_str .encode ("utf-8" )
37- assert len (sequence_str_bytes ) < STOP_SEQUENCE_STR_MAX_LENGTH , "stop sequence string too long."
38- self .sequence_str = sequence_str_bytes
39- self .sequence_str_len = len (sequence_str_bytes )
36+ if sequence_str is not None :
37+ sequence_str_bytes = sequence_str .encode ("utf-8" )
38+ assert len (sequence_str_bytes ) < STOP_SEQUENCE_STR_MAX_LENGTH , "stop sequence string too long."
39+ self .sequence_str = sequence_str_bytes
40+ self .sequence_str_len = len (sequence_str_bytes )
41+ else :
42+ self .sequence_str_len = 0
4043
41- def to_list (self ):
44+ def to_list (self ) -> List [ int ] :
4245 return list (self .sequence [0 : self .size ])
4346
44- def to_string (self ):
47+ def to_string (self ) -> str :
4548 return bytes (self .sequence_str [0 : self .sequence_str_len ]).decode ("utf-8" )
4649
4750
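A minimal sketch of how the updated StopSequence.initialize behaves after this change (hypothetical usage, not part of the diff; the token ids shown are illustrative and the STOP_SEQUENCE_* constants are the ones defined in this module):

# Hypothetical usage sketch, not part of the patch above.
seq = StopSequence()
# Stop sequence that originated from a string: both token ids and the raw
# UTF-8 bytes are stored, so to_string() can recover the original text.
seq.initialize([13, 13], sequence_str="\n\n")   # ids are illustrative
assert seq.to_list() == [13, 13]
assert seq.to_string() == "\n\n"

# Stop sequence given directly as token ids: sequence_str is omitted, so
# sequence_str_len stays 0 and to_string() yields an empty string.
seq2 = StopSequence()
seq2.initialize([2])
assert seq2.to_list() == [2]
assert seq2.to_string() == ""
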
@@ -52,45 +55,51 @@ class StopSequenceGroups(ctypes.Structure):
5255 ("size" , ctypes .c_int ),
5356 ]
5457
55- def initialize (self , stop_sequences : Union [str , List ], tokenizer ):
58+ def initialize (self , stop_sequences : Union [str , List [Union [List [int ], str ]]], tokenizer ):
59+ if stop_sequences is None :
60+ stop_sequences = []
61+ elif isinstance (stop_sequences , str ):
62+ stop_sequences = [stop_sequences ]
63+
5664 groups : List [List [int ]] = self .stop_sentences_to_token_ids (stop_sequences , tokenizer )
5765 self .size = len (groups )
5866 assert self .size <= MAX_STOP_SEQUENCES , "Too many stop sequence groups."
59- if isinstance (stop_sequences , str ):
60- stop_sequences = [stop_sequences ]
61- for group_idx in range (self .size ):
62- self .groups [group_idx ].initialize (groups [group_idx ], stop_sequences [group_idx ])
6367
64- def stop_sentences_to_token_ids (self , stop_sequences , tokenizer ):
65- if stop_sequences is None :
66- stop_sequences = []
67- else :
68- if isinstance (stop_sequences , str ):
69- stop_sequences = [stop_sequences ]
70-
71- new_stop_sequences = []
72- for stop_info in stop_sequences :
73- if isinstance (stop_info , str ):
74- stop_str_ids = self ._stop_str_to_token_ids (stop_info , tokenizer )
75- if stop_str_ids is not None and len (stop_str_ids ) > 0 :
76- new_stop_sequences .append (stop_str_ids )
77- if isinstance (stop_info , list ):
78- if all (isinstance (x , int ) for x in stop_info ):
79- if len (stop_info ) > 0 :
80- new_stop_sequences .append (stop_info )
81- stop_sequences = new_stop_sequences
82- return stop_sequences
83-
84- def _stop_str_to_token_ids (self , stop_str : str , tokenizer ):
68+ for group_idx in range (self .size ):
69+ if isinstance (stop_sequences [group_idx ], str ):
70+ self .groups [group_idx ].initialize (groups [group_idx ], sequence_str = stop_sequences [group_idx ])
71+ else :
72+ self .groups [group_idx ].initialize (groups [group_idx ])
73+
74+ def stop_sentences_to_token_ids (self , stop_sequences : List [Union [List [int ], str ]], tokenizer ) -> List [List [int ]]:
75+ new_stop_sequences = []
76+ for stop_info in stop_sequences :
77+ if isinstance (stop_info , str ):
78+ stop_str_ids = self ._stop_str_to_token_ids (stop_info , tokenizer )
79+ if stop_str_ids is not None and len (stop_str_ids ) > 0 :
80+ new_stop_sequences .append (stop_str_ids )
81+ if isinstance (stop_info , list ):
82+ if all (isinstance (x , int ) for x in stop_info ):
83+ if len (stop_info ) > 0 :
84+ new_stop_sequences .append (stop_info )
85+ else :
86+ assert False , "stop_sequences item must be type List[int] when it is a list."
87+ return new_stop_sequences
88+
89+ def _stop_str_to_token_ids (self , stop_str : str , tokenizer ) -> List [int ]:
8590 stop_str_ids = tokenizer .encode (stop_str , add_special_tokens = False )
8691 return stop_str_ids
8792
88- def to_list (self ):
93+ def to_list (self ) -> List [ List [ int ]] :
8994 return [self .groups [i ].to_list () for i in range (self .size )]
9095
91- def to_string (self ):
96+ def to_strings (self ) -> List [ str ] :
        # Match in descending length order: when both "\n\n" and "\n" are present, prefer "\n\n".
-        return sorted([self.groups[i].to_string() for i in range(self.size)], key=len, reverse=True)
+        return sorted(
+            [self.groups[i].to_string() for i in range(self.size) if self.groups[i].sequence_str_len > 0],
+            key=len,
+            reverse=True,
+        )


class RegularConstraint(ctypes.Structure):
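For context, a rough usage sketch of the reworked StopSequenceGroups path (hypothetical, not part of the diff; assumes a HuggingFace-style tokenizer whose encode(..., add_special_tokens=False) returns token ids, and the model name and ids shown are illustrative only):

# Hypothetical usage sketch, not part of the patch above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice

groups = StopSequenceGroups()
# Mixed input: a raw string (tokenized via the tokenizer) and a ready-made
# token-id list. Only the string-backed group keeps a printable form.
groups.initialize(["\n\nObservation", [50256]], tokenizer)

print(groups.to_list())     # [[...token ids for "\n\nObservation"...], [50256]]
print(groups.to_strings())  # ["\n\nObservation"]; id-only groups are skipped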