Commit 320bd2b

Typecheck more with mypyc
1 parent 44b52ba commit 320bd2b

tokenizer/rwkv_tokenizer.py

Lines changed: 20 additions & 16 deletions
@@ -217,34 +217,37 @@ def printTokens(self, tokens):
 # Tokenizer #4 (fast) https://github.com/LoganDark
 ########################################################################################################

-from typing import Generator
+from typing import Generator, Iterable
 from ast import literal_eval

 class FastTokenizer:
     __slots__ = ('tok2val', 'root')

-    def __init__(self, file_name):
+    tok2val: Dict[int, bytes]
+    root: Dict[int, Entry]
+
+    def __init__(self, file_name) -> None:
         self.tok2val = {}
         self.root = {}

         with open(file_name, 'rt', encoding = 'utf-8') as file:
             for line in file:
-                token, value = line.rstrip().split(' ', 1)
-                value, expected_len = value.rsplit(' ', 1)
-                value = literal_eval(value)
-                if isinstance(value, str): value = value.encode('utf-8')
-                token, value, expected_len = int(token), value, int(expected_len)
-                assert len(value) == expected_len
-                self.add_token(token, value)
-
-    def add_token(self, token: int, value: bytes):
+                token_str, value_repr = line.rstrip().split(' ', 1)
+                value_repr, len_str = value_repr.rsplit(' ', 1)
+                value_str: Union[bytes, str] = literal_eval(value_repr)
+                value = value_str if isinstance(value_str, bytes) else value_str.encode('utf-8')
+                assert len(value) == int(len_str)
+                self.add_token(int(token_str), value)
+
+    def add_token(self, token: int, value: bytes) -> None:
         self.tok2val[token] = value
         pos = self.root
         for byte in value[:-1]: pos = pos.setdefault(byte, (None, {}))[1]
         pos.setdefault(value[-1], (token, {}))

     def next_token(self, src: bytes) -> Optional[int]:
-        last_token, last = None, self.root
+        last_token: Optional[int] = None
+        last = self.root
         for i in range(0, len(src)):
             if current := last.get(src[i]):
                 if token := current[0]: last_token = token
@@ -255,7 +258,8 @@ def next_token(self, src: bytes) -> Optional[int]:
     def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
         start, stop = 0, len(src)
         while start < stop:
-            last_token, last = None, self.root
+            last_token: Optional[int] = None
+            last = self.root

             for i in range(start, stop):
                 if current := last.get(src[i]):
@@ -268,13 +272,13 @@ def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
             if last_token: yield last_token
             else: break

-    def decode_bytes(self, tokens: list[int]) -> bytes:
-        return b''.join(map(self.tok2val.get, tokens))
+    def decode_bytes(self, tokens: Iterable[int]) -> bytes:
+        return b''.join(map(self.tok2val.__getitem__, tokens))

     def encode(self, src: str) -> Generator[int, None, None]:
         return self.encode_bytes(src.encode('utf-8'))

-    def decode(self, tokens: list[int]) -> str:
+    def decode(self, tokens: Iterable[int]) -> str:
         return self.decode_bytes(tokens).decode('utf-8')

 ########################################################################################################
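
The new class-level annotations reference Dict, Union, Optional, and an Entry type that this diff neither imports nor defines, so they presumably live earlier in rwkv_tokenizer.py. A minimal sketch of the typing context these annotations assume, with Entry as a hypothetical recursive alias matching the (token, children) tuples that add_token stores:

    from typing import Dict, Optional, Tuple, Union

    # Assumed definition, not shown in this commit: each trie entry pairs an
    # optional token id with its children, keyed by the next byte value.
    Entry = Tuple[Optional[int], Dict[int, 'Entry']]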
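
The decode_bytes change is not just stylistic: dict.get returns Optional[bytes], which bytes.join rejects under a type checker (and which would surface as a confusing TypeError at runtime for an unknown token), whereas dict.__getitem__ returns bytes and fails fast with a KeyError. A small illustration (the names here are invented, not from the commit):

    from typing import Dict

    tok2val: Dict[int, bytes] = {1: b'a', 2: b'b'}

    b''.join(map(tok2val.__getitem__, [1, 2]))  # OK: every element is bytes
    # b''.join(map(tok2val.get, [1, 2]))        # rejected by mypy: get() may return None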
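
For context, __init__ parses a vocabulary file whose lines have the form "<token id> <Python literal> <byte length>". A hedged round-trip sketch under that assumption; the file name and toy tokens below are invented for illustration:

    # Write a three-token toy vocabulary in the format __init__ expects.
    with open('toy_vocab.txt', 'w', encoding='utf-8') as f:
        f.write("1 'h' 1\n2 'i' 1\n3 'hi' 2\n")

    tok = FastTokenizer('toy_vocab.txt')
    ids = list(tok.encode('hi'))   # greedy longest match in the byte trie -> [3]
    assert tok.decode(ids) == 'hi'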

Comments (0)