Commit 8cbe6cd


hf_tokenizer.py generated
1 parent d99970b commit 8cbe6cd

File tree

3 files changed: +211 -6 lines

examples/models/llama/runner/generation.py

Lines changed: 2 additions & 1 deletion
@@ -71,7 +71,8 @@ def __init__(
         self.use_kv_cache = use_kv_cache
         self.tokenizer = get_tokenizer(tokenizer_path)
         self.device = device
-        assert vocab_size == self.tokenizer.n_words
+        # For qwen anything above 151646 is "useless": https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706
+        # assert vocab_size == self.tokenizer.n_words
 
     @abstractmethod
     def forward(
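
Rather than dropping the check entirely, a looser sanity check would still catch real mismatches while tolerating padded embedding tables like Qwen's. A minimal sketch, not part of this commit, assuming the same vocab_size and self.tokenizer.n_words as above:

        # Hypothetical relaxed check (illustration only): the embedding table may be
        # padded beyond the tokenizer vocabulary, so only fail when the tokenizer
        # could emit ids the model cannot embed.
        if vocab_size < self.tokenizer.n_words:
            raise ValueError(
                f"vocab_size ({vocab_size}) < tokenizer n_words "
                f"({self.tokenizer.n_words}); token ids would be out of range"
            )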
extension/llm/tokenizer/hf_tokenizer.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
import json
import os
import regex as re  # third-party "regex" module: the \p{...} classes used below are not supported by stdlib re
from typing import Dict, List, Optional


class HFTokenizer:
    """Minimal tokenizer that loads its vocabulary, special tokens, and
    pre-tokenizer settings from a Hugging Face tokenizer.json file."""

    def __init__(self):
        self.special_token_encoder: Dict[str, int] = {}
        self.special_token_decoder: Dict[int, str] = {}
        self.encoder: Dict[str, int] = {}
        self.decoder: Dict[int, str] = {}
        self.n_words: int = 0
        self.bos_id: Optional[int] = None
        self.eos_id: Optional[int] = None
        self.initialized: bool = False
        self.pre_tokenizer_config = None

    def load(self, path: str) -> bool:
        if os.path.isdir(path):
            model_json = os.path.join(path, "tokenizer.json")
            model_config_json = os.path.join(path, "tokenizer_config.json")
        else:
            model_json = path
            model_config_json = ""

        if not os.path.exists(model_json):
            print(f"no tokenizer.json found in {path}")
            return False

        try:
            with open(model_json, "r") as file:
                parsed_json = json.load(file)
        except json.JSONDecodeError as e:
            print(f"Error parsing json file: {e}")
            return False

        # Parse special tokens
        try:
            special_tokens = parsed_json["added_tokens"]
            for token_info in special_tokens:
                token = token_info["content"]
                token_id = token_info["id"]
                if token in self.special_token_encoder:
                    print(f"duplicate special token: {token}")
                    return False
                if token_id in self.special_token_decoder:
                    print(f"duplicate special token id: {token_id}")
                    return False
                self.special_token_encoder[token] = token_id
                self.special_token_decoder[token_id] = token
        except KeyError as e:
            print(f"Could not parse special tokens: {e}")
            return False

        # Parse standard tokens
        try:
            vocab = parsed_json["model"]["vocab"]
            for token, token_id in vocab.items():
                if token_id not in self.special_token_decoder:
                    if token in self.encoder:
                        print(f"duplicate token: {token}")
                        return False
                    if token_id in self.decoder:
                        print(f"duplicate token id: {token_id}")
                        return False
                    self.encoder[token] = token_id
                    self.decoder[token_id] = token
        except KeyError as e:
            print(f"Could not parse tokens: {e}")
            return False

        self.n_words = len(self.encoder) + len(self.special_token_encoder)

        # Parse tokenizer config if available
        if model_config_json and os.path.exists(model_config_json):
            try:
                with open(model_config_json, "r") as file:
                    config_json = json.load(file)
                bos_token = config_json["bos_token"]
                eos_token = config_json["eos_token"]
                if bos_token not in self.special_token_encoder:
                    print(f"BOS token {bos_token} not in special tokens")
                    return False
                if eos_token not in self.special_token_encoder:
                    print(f"EOS token {eos_token} not in special tokens")
                    return False
                self.bos_id = self.special_token_encoder[bos_token]
                self.eos_id = self.special_token_encoder[eos_token]
            except KeyError as e:
                print(f"Could not parse eos/bos from tokenizer config: {e}")
                return False
        else:
            # Guess BOS and EOS tokens from the special-token names
            bos_candidates = []
            eos_candidates = []
            for token in self.special_token_encoder:
                if "bos" in token or "begin" in token:
                    bos_candidates.append(token)
                if "eos" in token or "end" in token:
                    eos_candidates.append(token)
            if len(bos_candidates) == 1:
                self.bos_id = self.special_token_encoder[bos_candidates[0]]
            if len(eos_candidates) == 1:
                self.eos_id = self.special_token_encoder[eos_candidates[0]]
            if self.bos_id is not None and self.eos_id is None:
                self.eos_id = self.bos_id
            elif self.eos_id is not None and self.bos_id is None:
                self.bos_id = self.eos_id

        # Parse pre-tokenizer configuration
        self.pre_tokenizer_config = parsed_json.get("pre_tokenizer", {})

        self.initialized = True
        return True

    def encode(self, text: str, bos: bool = False, eos: bool = False) -> List[int]:
        if not self.initialized:
            raise ValueError("Tokenizer not initialized")
        tokens = []
        for piece in self._pretokenize(text):
            if piece in self.encoder:
                tokens.append(self.encoder[piece])
            else:
                # Handle unknown tokens (e.g., byte pair encoding)
                pass
        if bos and self.bos_id is not None:
            tokens = [self.bos_id] + tokens
        if eos and self.eos_id is not None:
            tokens.append(self.eos_id)
        return tokens

    def decode(self, tokens: List[int]) -> str:
        if not self.initialized:
            raise ValueError("Tokenizer not initialized")
        text = ""
        for token in tokens:
            if token in self.decoder:
                text += self.decoder[token]
            elif token in self.special_token_decoder:
                text += self.special_token_decoder[token]
            else:
                # Handle unknown tokens
                pass
        return text

    def _pretokenize(self, text: str) -> List[str]:
        if not self.pre_tokenizer_config:
            return [text]  # Default to no pre-tokenization

        pre_tokenizer_type = self.pre_tokenizer_config.get("type", "")
        if pre_tokenizer_type == "Split":
            return self._split_pretokenize(text)
        elif pre_tokenizer_type == "Digits":
            return self._digits_pretokenize(text)
        elif pre_tokenizer_type == "ByteLevel":
            return self._byte_level_pretokenize(text)
        elif pre_tokenizer_type == "Sequence":
            return self._sequence_pretokenize(text)
        else:
            return [text]  # Unsupported pre-tokenizer type

    def _split_pretokenize(self, text: str) -> List[str]:
        pattern = self.pre_tokenizer_config.get("pattern", "")
        if not pattern:
            return [text]
        return re.split(f"({pattern})", text)

    def _digits_pretokenize(self, text: str) -> List[str]:
        individual_digits = self.pre_tokenizer_config.get("individual_digits", False)
        if individual_digits:
            return list(text)  # Split into individual characters
        else:
            return re.split(r"(\d+)", text)  # Split on digit runs

    def _byte_level_pretokenize(self, text: str) -> List[str]:
        add_prefix_space = self.pre_tokenizer_config.get("add_prefix_space", False)
        pattern = self.pre_tokenizer_config.get("pattern", "")
        if add_prefix_space and not text.startswith(" "):
            text = " " + text
        if not pattern:
            # GPT-2 style default split pattern; \p{L}/\p{N} need the regex module.
            pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
        return re.findall(pattern, text)

    def _sequence_pretokenize(self, text: str) -> List[str]:
        pretokenizers = self.pre_tokenizer_config.get("pretokenizers", [])
        pieces = [text]
        saved_config = self.pre_tokenizer_config
        for pretokenizer_config in pretokenizers:
            # Temporarily dispatch on each child pre-tokenizer config so that
            # _pretokenize applies the Split/Digits/ByteLevel handlers instead of
            # re-entering the Sequence branch.
            self.pre_tokenizer_config = pretokenizer_config
            new_pieces = []
            for piece in pieces:
                new_pieces.extend(self._pretokenize(piece))
            pieces = new_pieces
        self.pre_tokenizer_config = saved_config
        return pieces
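
For reference, a minimal round trip with the new class could look like this. This is a sketch, not part of the commit; the path is illustrative, and note that encode() currently drops any piece that is not a whole vocabulary entry (no BPE merges yet), so most inputs will not round-trip losslessly.

# Hypothetical usage of HFTokenizer; the tokenizer.json path is an example.
tok = HFTokenizer()
if tok.load("/path/to/qwen2_5/tokenizer.json"):
    ids = tok.encode("hello", bos=True)
    print(ids)              # token ids, prefixed with bos_id if one was found
    print(tok.decode(ids))  # raw concatenation of the decoded pieces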

extension/llm/tokenizer/utils.py

Lines changed: 11 additions & 5 deletions
@@ -8,12 +8,18 @@
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
+from executorch.extension.llm.tokenizer.hf_tokenizer import HFTokenizer
 
 
 def get_tokenizer(tokenizer_path):
-    try:
-        tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
-    except Exception:
-        print("Using Tiktokenizer")
-        tokenizer = Tiktoken(model_path=str(tokenizer_path))
+    if tokenizer_path.endswith(".json"):
+        print("Using Hugging Face tokenizer")
+        tokenizer = HFTokenizer()
+        tokenizer.load(tokenizer_path)
+    else:
+        try:
+            tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
+        except Exception:
+            print("Using Tiktokenizer")
+            tokenizer = Tiktoken(model_path=str(tokenizer_path))
     return tokenizer
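
The new dispatch keys purely on the file extension. A hedged usage sketch (paths are illustrative, not from this commit):

# .json paths go to the new HFTokenizer; anything else falls back to
# SentencePiece, then Tiktoken, as in get_tokenizer above.
hf_tok = get_tokenizer("qwen2_5/tokenizer.json")
sp_tok = get_tokenizer("llama2/tokenizer.model")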
