@@ -204,18 +204,16 @@ def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
         self.version: int = int(_mistral_version_str.split("v")[-1])
 
         tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
-        from mistral_common.tokens.tokenizers.tekken import (
-            SpecialTokenPolicy, Tekkenizer)
+        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
+        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+
         self.is_tekken = isinstance(tokenizer_, Tekkenizer)
         from mistral_common.tokens.tokenizers.sentencepiece import (
             SentencePieceTokenizer)
         self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
-        if self.is_tekken:
-            # Make sure special tokens will not raise
-            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
-        elif self.is_spm:
-            pass
-        else:
+        self._special_token_policy = (SpecialTokenPolicy.IGNORE
+                                      if self.is_tekken else None)
+        if not (self.is_tekken or self.is_spm):
             raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
         self._vocab = tokenizer_.vocab()
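
The net effect of this first hunk: instead of mutating tokenizer_.special_token_policy in place, the wrapper remembers the policy in self._special_token_policy and passes it explicitly to the decode calls changed below. A minimal standalone sketch of that selection logic (the helper name is made up for illustration; the import paths follow the mistral_common version this diff targets, where SpecialTokenPolicy lives in tokens.tokenizers.base):

    from typing import Optional

    from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
    from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer

    def pick_special_token_policy(tokenizer_) -> Optional[SpecialTokenPolicy]:
        # Hypothetical helper mirroring the hunk above: Tekken tokenizers get
        # IGNORE so decoding special tokens never raises; SentencePiece keeps
        # the library default (None); anything else is rejected up front.
        if isinstance(tokenizer_, Tekkenizer):
            return SpecialTokenPolicy.IGNORE
        if isinstance(tokenizer_, SentencePieceTokenizer):
            return None
        raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
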
@@ -430,7 +428,8 @@ def _token_to_id(t: str):
                         return self.tokenizer.unk_id
 
                 ids = [_token_to_id(t) for t in tokens]
-                decoded = self.tokenizer.decode(ids)
+                decoded = self.tokenizer.decode(ids,
+                                                self._special_token_policy)
             else:
                 decoded = "".join(tokens)
         else:
@@ -444,15 +443,17 @@ def _token_to_id(t: str):
                 if token in special_tokens:
                     if regular_tokens:
                         decoded_list.append(
-                            self.tokenizer.decode(regular_tokens))
+                            self.tokenizer.decode(regular_tokens,
+                                                  self._special_token_policy))
                         regular_tokens = []
                     decoded_list.append(token)
                 else:
                     regular_tokens.append(token)
 
             if regular_tokens:
                 decoded_list.append(
-                    self.tokenizer.decode(regular_tokens))  # type: ignore
+                    self.tokenizer.decode(regular_tokens,
+                                          self._special_token_policy))
 
             decoded = ''.join(decoded_list)
 
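
For the SentencePiece path, the loop above decodes runs of regular tokens as a group and splices the special tokens through verbatim, now handing the stored policy to each group decode. The same grouping logic, pulled out as a hypothetical standalone helper (names are illustrative, not part of the change):

    def join_with_specials(tokens, special_tokens, decode_fn):
        # Decode contiguous runs of regular tokens together; keep special
        # tokens as literal strings in between.
        decoded_list, regular_tokens = [], []
        for token in tokens:
            if token in special_tokens:
                if regular_tokens:
                    decoded_list.append(decode_fn(regular_tokens))
                    regular_tokens = []
                decoded_list.append(token)
            else:
                regular_tokens.append(token)
        if regular_tokens:
            decoded_list.append(decode_fn(regular_tokens))
        return "".join(decoded_list)
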
@@ -470,7 +471,7 @@ def decode(self,
 
         if isinstance(ids, int):
             ids = [ids]
-        return self.tokenizer.decode(ids)
+        return self.tokenizer.decode(ids, self._special_token_policy)
 
     def convert_ids_to_tokens(
         self,
@@ -511,6 +512,9 @@ def convert_ids_to_tokens(
         # See: https://github.com/vllm-project/vllm/pull/8640
         #      https://github.com/vllm-project/vllm/pull/9625
         # if underlying tokenizeir is sentencepiece, we just add "�"
-        tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
+        tokens = [
+            self.tokenizer.id_to_byte_piece(id, self._special_token_policy)
+            for id in ids
+        ]
 
         return tokens
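
Taken together, every decoding entry point touched here (convert_tokens_to_string, decode, convert_ids_to_tokens) now threads self._special_token_policy through instead of relying on tokenizer-level state. A hedged usage sketch of the wrapper after this change; the model name and the from_pretrained/encode calls are assumptions for illustration, while the decode call is the one modified in this diff:

    from vllm.transformers_utils.tokenizers import MistralTokenizer

    # Assumed: the wrapper exposes from_pretrained and encode; decode applies
    # the SpecialTokenPolicy chosen in __init__ internally.
    tok = MistralTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")
    ids = tok.encode("Hello, world!")
    print(tok.decode(ids))  # special tokens handled via the stored policy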