@@ -217,34 +217,37 @@ def printTokens(self, tokens):
 # Tokenizer #4 (fast) https://github.com/LoganDark
 ########################################################################################################
 
-from typing import Generator
+from typing import Generator, Iterable
 from ast import literal_eval
 
 class FastTokenizer:
     __slots__ = ('tok2val', 'root')
 
-    def __init__(self, file_name):
+    tok2val: Dict[int, bytes]
+    root: Dict[int, Entry]
+
+    def __init__(self, file_name) -> None:
         self.tok2val = {}
         self.root = {}
 
         with open(file_name, 'rt', encoding='utf-8') as file:
             for line in file:
-                token, value = line.rstrip().split(' ', 1)
-                value, expected_len = value.rsplit(' ', 1)
-                value = literal_eval(value)
-                if isinstance(value, str): value = value.encode('utf-8')
-                token, value, expected_len = int(token), value, int(expected_len)
-                assert len(value) == expected_len
-                self.add_token(token, value)
-
-    def add_token(self, token: int, value: bytes):
+                token_str, value_repr = line.rstrip().split(' ', 1)
+                value_repr, len_str = value_repr.rsplit(' ', 1)
+                value_str: Union[bytes, str] = literal_eval(value_repr)
+                value = value_str if isinstance(value_str, bytes) else value_str.encode('utf-8')
+                assert len(value) == int(len_str)
+                self.add_token(int(token_str), value)
+
+    def add_token(self, token: int, value: bytes) -> None:
         self.tok2val[token] = value
         pos = self.root
         for byte in value[:-1]: pos = pos.setdefault(byte, (None, {}))[1]
         pos.setdefault(value[-1], (token, {}))
 
     def next_token(self, src: bytes) -> Optional[int]:
-        last_token, last = None, self.root
+        last_token: Optional[int] = None
+        last = self.root
         for i in range(0, len(src)):
             if current := last.get(src[i]):
                 if token := current[0]: last_token = token
@@ -255,7 +258,8 @@ def next_token(self, src: bytes) -> Optional[int]:
     def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
         start, stop = 0, len(src)
         while start < stop:
-            last_token, last = None, self.root
+            last_token: Optional[int] = None
+            last = self.root
 
             for i in range(start, stop):
                 if current := last.get(src[i]):
@@ -268,13 +272,13 @@ def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
             if last_token: yield last_token
             else: break
 
-    def decode_bytes(self, tokens: list[int]) -> bytes:
-        return b''.join(map(self.tok2val.get, tokens))
+    def decode_bytes(self, tokens: Iterable[int]) -> bytes:
+        return b''.join(map(self.tok2val.__getitem__, tokens))
 
     def encode(self, src: str) -> Generator[int, None, None]:
         return self.encode_bytes(src.encode('utf-8'))
 
-    def decode(self, tokens: list[int]) -> str:
+    def decode(self, tokens: Iterable[int]) -> str:
         return self.decode_bytes(tokens).decode('utf-8')
 
 ########################################################################################################
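
The new annotations for tok2val and root reference Dict, Union, Optional and an
Entry type that are defined or imported outside these hunks. Judging from the
trie that add_token builds, Entry is presumably a recursive alias for a node
pair; a hypothetical definition consistent with that shape:

from typing import Dict, Optional, Tuple

# Hypothetical alias (assumed, not shown in this diff): each trie node pairs
# the token id that ends at this byte (or None) with the children of the node,
# keyed by the next byte.
Entry = Tuple[Optional[int], Dict[int, 'Entry']]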
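For context, __init__ reads one token per line in the form
"<id> <python-literal> <byte-length>": the id is split off at the first space,
the byte length at the last space, and the middle is passed through
literal_eval to yield bytes or a str (which is then UTF-8 encoded); the
trailing length is asserted as a sanity check, since the literal itself may
contain spaces. A minimal round-trip sketch, with illustrative vocabulary
lines and a hypothetical file name (neither taken from a real vocabulary
file):

# vocab.txt (illustrative lines the loader would accept):
#   65 'A' 1
#   2402 ' the' 4
#   261 b'\xe2\x80' 2

tokenizer = FastTokenizer('vocab.txt')    # hypothetical file name
tokens = list(tokenizer.encode('A the'))  # encode() returns a generator
assert tokens == [65, 2402]               # greedy longest match per position
assert tokenizer.decode(tokens) == 'A the'

Note also that decode_bytes now fails loudly: tok2val.__getitem__ raises
KeyError on an unknown token id, where the previous tok2val.get returned None
and only failed later inside b''.join() with a less helpful TypeError.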