@@ -154,19 +154,22 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         '\uFEFF//',                 # unicode_ranges_control, 0xFEFF (BOM)
         'Cửa Việt',                 # llama-3, ignore_merges = true
         '<s>a',                     # Phi-3 fail
-        '<unk><|endoftext|><s>'     # Phi-3 fail
+        '<unk><|endoftext|><s>',    # Phi-3 fail
         'a\na',                     # TODO: Bert fail
     ]
 
 
-def generator_random_special_tokens(special_tokens: list[str], iterations=100) -> Iterator[str]:
-    special_tokens = set(special_tokens)
+def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
+    special_tokens = set(tokenizer.all_special_tokens)
     special_tokens.update([" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"])
     special_tokens = list(sorted(special_tokens))
     rand = random.Random()
     for m in range(iterations):
         rand.seed(m)
         words = rand.choices(special_tokens, k=500)
+        if tokenizer.add_bos_token:  # skip spam warning of double BOS
+            while words and words[0] == tokenizer.bos_token:
+                words.pop(0)
         yield "".join(words)
 
 
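The BOS handling added above keeps the generator from emitting strings that start with the BOS token when the tokenizer will prepend one itself, which would otherwise flood the log with Hugging Face's "double BOS" warning on every encode. A minimal sketch of that check in isolation, assuming only the add_bos_token / bos_token attributes used in the diff; the helper name and sample token list are illustrative:

    import random

    def skip_leading_bos(words: list[str], add_bos_token: bool, bos_token: str) -> list[str]:
        # If the encoder adds BOS itself, a text that already begins with the BOS
        # string would tokenize as <s><s>... and trigger the double-BOS warning.
        if add_bos_token:
            while words and words[0] == bos_token:
                words = words[1:]
        return words

    # Illustrative use: build a reproducible random special-token string.
    rand = random.Random(0)
    words = rand.choices(["<s>", "</s>", "<unk>", " ", "one", "1"], k=20)
    text = "".join(skip_leading_bos(words, add_bos_token=True, bos_token="<s>"))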
@@ -290,18 +293,19 @@ def main(argv: list[str] = None):
     model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
 
-    def func_tokenize2(text: str):
-        return tokenizer.encode(text, add_special_tokens=False)
-
-    parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
+    tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", True)
+    tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", False)
 
     def func_tokenize1(text: str):
-        return model.tokenize(text, add_special=False, parse_special=parse_special)
+        return model.tokenize(text, add_special=True, parse_special=True)
+
+    def func_tokenize2(text: str):
+        return tokenizer.encode(text, add_special_tokens=True)
 
     vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_special_tokens(tokenizer.all_special_tokens, 10_000))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_special_tokens(tokenizer, 10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
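With both sides now adding special tokens (add_special=True / add_special_tokens=True), the comparison covers BOS/EOS insertion as well as the raw text. A minimal sketch of the Hugging Face half of that setup, assuming transformers is installed; the model directory is illustrative and hf_tokenize is a hypothetical stand-in for what the diff calls func_tokenize2:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/dir_tokenizer")  # illustrative path

    # Some tokenizer classes do not expose these flags; fall back to the common
    # defaults (BOS added, EOS not) so the special-token generator can read them.
    tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", True)
    tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", False)

    def hf_tokenize(text: str) -> list[int]:
        # Mirrors model.tokenize(text, add_special=True, parse_special=True) on the
        # llama.cpp side: special tokens are both parsed from the text and added.
        return tokenizer.encode(text, add_special_tokens=True)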