@@ -45,26 +45,25 @@ class MistralTokenizer:
     def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
-        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
 
-        self.vocab_size = len(self.tokenizer.vocab())
-
-        assert isinstance(self.tokenizer,
-                          (Tekkenizer, SentencePieceTokenizer)), type(
-                              self.tokenizer)
-
-        if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
+        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        if isinstance(tokenizer_, Tekkenizer):
             # Make sure special tokens will not raise
-            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
-
-        self._is_tekken = is_tekken
+            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
+
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        elif isinstance(tokenizer_, SentencePieceTokenizer):
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        else:
+            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
-        # the following attributes are set to fit VLLM's design
-        self.is_fast = True
-        self.chat_template = True
-        self.all_special_ids: List[Any] = []
-        self.all_special_tokens: List[Any] = []
-        self.all_special_tokens_extended: List[Any] = []
+        self.tokenizer = tokenizer_
 
     @classmethod
     def from_pretrained(cls,
@@ -102,6 +101,38 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                          revision=revision)
         return tokenizer_file
 
+    # the following attributes are set to fit VLLM's design
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        return []
+
+    @property
+    def bos_token_id(self) -> int:
+        return self.tokenizer.bos_id
+
+    @property
+    def eos_token_id(self) -> int:
+        return self.tokenizer.eos_id
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
     def __call__(
         self,
         prompt: str,
@@ -117,9 +148,12 @@ def __call__(
 
         return Encoding(input_ids=input_ids)
 
-    def get_added_vocab(self) -> List[str]:
+    def get_vocab(self) -> Dict[str, int]:
+        return self._vocab
+
+    def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
-        return []
+        return {}
 
     def encode(self, prompt: str) -> List[int]:
         # `encode` should only be used for prompt completion
@@ -141,7 +175,7 @@ def apply_chat_template(self,
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if self._is_tekken:
+        if isinstance(self.tokenizer, Tekkenizer):
             return "".join(tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
@@ -151,14 +185,11 @@ def decode(self, ids: Union[List[int], int]) -> str:
             ids = [ids]
         return self.tokenizer.decode(ids)
 
-    @property
-    def eos_token_id(self):
-        return self.tokenizer.eos_id
-
     def convert_ids_to_tokens(
-            self,
-            ids: List[int],
-            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens
@@ -170,6 +201,3 @@ def convert_ids_to_tokens(
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
         return tokens
-
-    def __len__(self):
-        return self.vocab_size
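
For reference, a minimal usage sketch of the refactored class (illustrative only, not part of the diff). It assumes mistral_common is installed and that from_pretrained accepts a Hugging Face repo id, as in vLLM; the repo id below is hypothetical.

    # Sketch: exercises the vocab-mapping API introduced by this change.
    tokenizer = MistralTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.3")  # hypothetical repo id

    # get_vocab() returns the Dict[str, int] built once in __init__ from the
    # underlying Tekkenizer / SentencePieceTokenizer vocab.
    vocab = tokenizer.get_vocab()
    assert len(tokenizer) == tokenizer.vocab_size == len(vocab)

    # Prompt-completion round trip through the methods touched by this diff.
    ids = tokenizer.encode("Hello, world!")
    text = tokenizer.decode(ids)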