From 8cbe6cdd7958bebf861a72c69de7b081baeb287e Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:04:11 -0800 Subject: [PATCH 1/8] hf_tokenizer.py generated --- examples/models/llama/runner/generation.py | 3 +- extension/llm/tokenizer/hf_tokenizer.py | 198 +++++++++++++++++++++ extension/llm/tokenizer/utils.py | 16 +- 3 files changed, 211 insertions(+), 6 deletions(-) create mode 100644 extension/llm/tokenizer/hf_tokenizer.py diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 891ce20db3e..9d7a604598b 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -71,7 +71,8 @@ def __init__( self.use_kv_cache = use_kv_cache self.tokenizer = get_tokenizer(tokenizer_path) self.device = device - assert vocab_size == self.tokenizer.n_words + # For qwen anything above 151646 is "useless": https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706 + # assert vocab_size == self.tokenizer.n_words @abstractmethod def forward( diff --git a/extension/llm/tokenizer/hf_tokenizer.py b/extension/llm/tokenizer/hf_tokenizer.py new file mode 100644 index 00000000000..777952dd465 --- /dev/null +++ b/extension/llm/tokenizer/hf_tokenizer.py @@ -0,0 +1,198 @@ +import json +import os +import re +from typing import Dict, List, Optional + +class HFTokenizer: + def __init__(self): + self.special_token_encoder: Dict[str, int] = {} + self.special_token_decoder: Dict[int, str] = {} + self.encoder: Dict[str, int] = {} + self.decoder: Dict[int, str] = {} + self.n_words: int = 0 + self.bos_id: Optional[int] = None + self.eos_id: Optional[int] = None + self.initialized: bool = False + self.pre_tokenizer_config = None + + def load(self, path: str) -> bool: + if os.path.isdir(path): + model_json = os.path.join(path, "tokenizer.json") + model_config_json = os.path.join(path, "tokenizer_config.json") + else: + model_json = path + model_config_json = "" + + if not os.path.exists(model_json): + print(f"no tokenizer.json found in {path}") + return False + + try: + with open(model_json, "r") as file: + parsed_json = json.load(file) + except json.JSONDecodeError as e: + print(f"Error parsing json file: {e}") + return False + + # Parse special tokens + try: + special_tokens = parsed_json["added_tokens"] + for token_info in special_tokens: + token = token_info["content"] + token_id = token_info["id"] + if token in self.special_token_encoder: + print(f"duplicate special token: {token}") + return False + if token_id in self.special_token_decoder: + print(f"duplicate special token id: {token_id}") + return False + self.special_token_encoder[token] = token_id + self.special_token_decoder[token_id] = token + except KeyError as e: + print(f"Could not parse special tokens: {e}") + return False + + # Parse standard tokens + try: + vocab = parsed_json["model"]["vocab"] + for token, token_id in vocab.items(): + if token_id not in self.special_token_decoder: + if token in self.encoder: + print(f"duplicate token: {token}") + return False + if token_id in self.decoder: + print(f"duplicate token id: {token_id}") + return False + self.encoder[token] = token_id + self.decoder[token_id] = token + except KeyError as e: + print(f"Could not parse tokens: {e}") + return False + + self.n_words = len(self.encoder) + len(self.special_token_encoder) + + # Parse tokenizer config if available + if model_config_json and os.path.exists(model_config_json): + try: + with open(model_config_json, 
"r") as file: + config_json = json.load(file) + bos_token = config_json["bos_token"] + eos_token = config_json["eos_token"] + if bos_token not in self.special_token_encoder: + print(f"BOS token {bos_token} not in special tokens") + return False + if eos_token not in self.special_token_encoder: + print(f"EOS token {eos_token} not in special tokens") + return False + self.bos_id = self.special_token_encoder[bos_token] + self.eos_id = self.special_token_encoder[eos_token] + except KeyError as e: + print(f"Could not parse eos/bos from tokenizer config: {e}") + return False + else: + # Guess BOS and EOS tokens + bos_candidates = [] + eos_candidates = [] + for token in self.special_token_encoder: + if "bos" in token or "begin" in token: + bos_candidates.append(token) + if "eos" in token or "end" in token: + eos_candidates.append(token) + if len(bos_candidates) == 1: + self.bos_id = self.special_token_encoder[bos_candidates[0]] + if len(eos_candidates) == 1: + self.eos_id = self.special_token_encoder[eos_candidates[0]] + if self.bos_id is not None and self.eos_id is None: + self.eos_id = self.bos_id + elif self.eos_id is not None and self.bos_id is None: + self.bos_id = self.eos_id + + # Parse pre-tokenizer configuration + try: + self.pre_tokenizer_config = parsed_json.get("pre_tokenizer", {}) + except KeyError as e: + print(f"Could not parse pre_tokenizer: {e}") + return False + + self.initialized = True + return True + + def encode(self, text: str, bos: bool = False, eos: bool = False) -> List[int]: + breakpoint() + if not self.initialized: + raise ValueError("Tokenizer not initialized") + tokens = [] + for piece in self._pretokenize(text): + if piece in self.encoder: + tokens.append(self.encoder[piece]) + else: + # Handle unknown tokens (e.g., byte pair encoding) + pass + if bos and self.bos_id is not None: + tokens = [self.bos_id] + tokens + if eos and self.eos_id is not None: + tokens.append(self.eos_id) + return tokens + + def decode(self, tokens: List[int]) -> str: + if not self.initialized: + raise ValueError("Tokenizer not initialized") + text = "" + for token in tokens: + if token in self.decoder: + text += self.decoder[token] + elif token in self.special_token_decoder: + text += self.special_token_decoder[token] + else: + # Handle unknown tokens + pass + return text + + def _pretokenize(self, text: str) -> List[str]: + if not self.pre_tokenizer_config: + return [text] # Default to no pre-tokenization + + breakpoint() + pre_tokenizer_type = self.pre_tokenizer_config.get("type", "") + if pre_tokenizer_type == "Split": + return self._split_pretokenize(text) + elif pre_tokenizer_type == "Digits": + return self._digits_pretokenize(text) + elif pre_tokenizer_type == "ByteLevel": + return self._byte_level_pretokenize(text) + elif pre_tokenizer_type == "Sequence": + return self._sequence_pretokenize(text) + else: + return [text] # Unsupported pre-tokenizer type + + def _split_pretokenize(self, text: str) -> List[str]: + pattern = self.pre_tokenizer_config.get("pattern", "") + if not pattern: + return [text] + return re.split(f"({pattern})", text) + + def _digits_pretokenize(self, text: str) -> List[str]: + individual_digits = self.pre_tokenizer_config.get("individual_digits", False) + if individual_digits: + return list(text) # Split into individual characters + else: + return re.split(r"(\d+)", text) # Split on digits + + def _byte_level_pretokenize(self, text: str) -> List[str]: + add_prefix_space = self.pre_tokenizer_config.get("add_prefix_space", False) + pattern = 
self.pre_tokenizer_config.get("pattern", "") + if add_prefix_space and not text.startswith(" "): + text = " " + text + if not pattern: + pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+" + return re.findall(pattern, text) + + def _sequence_pretokenize(self, text: str) -> List[str]: + pretokenizers = self.pre_tokenizer_config.get("pretokenizers", []) + pieces = [text] + for pretokenizer_config in pretokenizers: + new_pieces = [] + for piece in pieces: + new_pieces.extend(self._pretokenize(piece)) + pieces = new_pieces + return pieces diff --git a/extension/llm/tokenizer/utils.py b/extension/llm/tokenizer/utils.py index 126a1203274..b7d7570d453 100644 --- a/extension/llm/tokenizer/utils.py +++ b/extension/llm/tokenizer/utils.py @@ -8,12 +8,18 @@ from executorch.extension.llm.tokenizer.tokenizer import ( Tokenizer as SentencePieceTokenizer, ) +from executorch.extension.llm.tokenizer.hf_tokenizer import HFTokenizer def get_tokenizer(tokenizer_path): - try: - tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path)) - except Exception: - print("Using Tiktokenizer") - tokenizer = Tiktoken(model_path=str(tokenizer_path)) + if tokenizer_path.endswith(".json"): + print("Using Hugging Face tokenizer") + tokenizer = HFTokenizer() + tokenizer.load(tokenizer_path) + else: + try: + tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path)) + except Exception: + print("Using Tiktokenizer") + tokenizer = Tiktoken(model_path=str(tokenizer_path)) return tokenizer From 88b3394f2cd20ffcbf8022b099cd2f54e1705115 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 10 Feb 2025 16:51:05 -0800 Subject: [PATCH 2/8] Add hugging face tokenizer --- examples/models/llama/runner/generation.py | 16 +++++++++++----- extension/llm/tokenizer/utils.py | 17 ++++++++++++++--- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 9d7a604598b..ecd067d9dd0 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -102,7 +102,8 @@ def generate( # noqa: C901 ) current_token = next_token(logits, temperature, top_p) - print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) + # print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) + print(f"{self.tokenizer.decode([current_token])}", end="", flush=True) tokens = prompt_tokens + [current_token] while len(tokens) < max_seq_len: @@ -132,7 +133,8 @@ def generate( # noqa: C901 ): break - print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) + # print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) + print(f"{self.tokenizer.decode([current_token])}", end="", flush=True) print("\n") return tokens if echo else tokens[len(prompt_tokens) :] @@ -160,7 +162,8 @@ def text_completion( This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. 
""" return self.generate( - prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False), + # prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False), + prompt_tokens=self.tokenizer.encode(prompt).ids, max_seq_len=self.max_seq_len, temperature=temperature, top_p=top_p, @@ -194,9 +197,12 @@ def chat_completion( prompt = input("Me: ") while prompt and prompt != exit_prompt: print("LLM: ", end="", flush=True) + # prompt_tokens = self.tokenizer.encode( + # self._format_prompt(prompt), bos=True, eos=False + # ) prompt_tokens = self.tokenizer.encode( - self._format_prompt(prompt), bos=True, eos=False - ) + self._format_prompt(prompt) + ).ids generated_tokens = self.generate( prompt_tokens=pre_stop_token + prompt_tokens, max_seq_len=max_seq_len, diff --git a/extension/llm/tokenizer/utils.py b/extension/llm/tokenizer/utils.py index b7d7570d453..a0502cb3212 100644 --- a/extension/llm/tokenizer/utils.py +++ b/extension/llm/tokenizer/utils.py @@ -13,9 +13,20 @@ def get_tokenizer(tokenizer_path): if tokenizer_path.endswith(".json"): - print("Using Hugging Face tokenizer") - tokenizer = HFTokenizer() - tokenizer.load(tokenizer_path) + # print("Using Hugging Face tokenizer") + # tokenizer = HFTokenizer() + # tokenizer.load(tokenizer_path) + + from tokenizers import Tokenizer + + # Load the tokenizer from the tokenizer.json file + tokenizer = Tokenizer.from_file(tokenizer_path) + + # from tokenizers import SentencePieceBPETokenizer + + # tokenizer = SentencePieceBPETokenizer(tokenizer_path) + tokenizer.n_words = tokenizer.get_vocab_size() + breakpoint() else: try: tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path)) From 70fd1feb6330ca5447f50831fa7cd023009db84c Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 10:54:54 -0800 Subject: [PATCH 3/8] Qwen runs with HF tokenizer --- examples/models/llama/install_requirements.sh | 12 +++---- examples/models/llama/runner/generation.py | 33 ++++++++++--------- examples/models/llama/runner/native.py | 16 +++++++++ .../llama3_2_vision/runner/generation.py | 13 ++++---- extension/llm/tokenizer/hf_tokenizer.py | 1 + extension/llm/tokenizer/utils.py | 23 +++++++------ 6 files changed, 59 insertions(+), 39 deletions(-) diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh index f72d0ea3cc2..cca6ede1d79 100755 --- a/examples/models/llama/install_requirements.sh +++ b/examples/models/llama/install_requirements.sh @@ -5,14 +5,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# Install sentencepiece for llama tokenizer. +# Install tiktoken for tokenizer. +# Install tokenizers for hf .json tokenizer. # Install snakeviz for cProfile flamegraph -# Install sentencepiece for llama tokenizer -pip install snakeviz sentencepiece - -# Install lm-eval for Model Evaluation with lm-evalution-harness -# Install tiktoken for tokenizer -pip install lm_eval==0.4.5 -pip install tiktoken blobfile +# Install lm-eval for Model Evaluation with lm-evalution-harness. 
+pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile # Call the install helper for further setup python examples/models/llama/install_requirement_helper.py diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index ecd067d9dd0..23d55b5461a 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -48,7 +48,9 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int: class LlamaRunner(ABC): def __init__( self, + *, tokenizer_path: str, + tokenizer_config_path: Optional[str] = None, max_seq_len: int, max_batch_size: int, use_kv_cache: bool, @@ -59,20 +61,23 @@ def __init__( Constructor. Args: - tokenizer_path: path to tokenizer.model file. - max_seq_len: max length of the output sequence, after which the output will be clipped. - max_batch_size: max batch size. - use_kv_cache: whether to use a KV cache. - vocab_size: number of items in the vocab. - device: device to run the runner on. + tokenizer_path: path to tokenizer.model file. + max_seq_len: max length of the output sequence, after which the output will be clipped. + max_batch_size: max batch size. + use_kv_cache: whether to use a KV cache. + vocab_size: number of items in the vocab. + device: device to run the runner on. """ self.max_seq_len = max_seq_len self.max_batch_size = max_batch_size self.use_kv_cache = use_kv_cache - self.tokenizer = get_tokenizer(tokenizer_path) + self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path) self.device = device - # For qwen anything above 151646 is "useless": https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706 - # assert vocab_size == self.tokenizer.n_words + # For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706 + if vocab_size != self.tokenizer.n_words: + print( + "Warning - given vocab_size in params is unequal to tokenizer vocab size." 
+ ) @abstractmethod def forward( @@ -102,8 +107,7 @@ def generate( # noqa: C901 ) current_token = next_token(logits, temperature, top_p) - # print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) - print(f"{self.tokenizer.decode([current_token])}", end="", flush=True) + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) tokens = prompt_tokens + [current_token] while len(tokens) < max_seq_len: @@ -133,8 +137,7 @@ def generate( # noqa: C901 ): break - # print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) - print(f"{self.tokenizer.decode([current_token])}", end="", flush=True) + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) print("\n") return tokens if echo else tokens[len(prompt_tokens) :] @@ -200,9 +203,7 @@ def chat_completion( # prompt_tokens = self.tokenizer.encode( # self._format_prompt(prompt), bos=True, eos=False # ) - prompt_tokens = self.tokenizer.encode( - self._format_prompt(prompt) - ).ids + prompt_tokens = self.tokenizer.encode(self._format_prompt(prompt)).ids generated_tokens = self.generate( prompt_tokens=pre_stop_token + prompt_tokens, max_seq_len=max_seq_len, diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index a6b055ced95..d71815edba5 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -37,6 +37,7 @@ def __init__(self, args): params = json.loads(f.read()) super().__init__( tokenizer_path=args.tokenizer, + tokenizer_config_path=args.tokenizer_config, max_seq_len=args.max_len, max_batch_size=1, use_kv_cache=args.kv_cache, @@ -56,6 +57,14 @@ def forward( )[0] +def validate_args(args) -> None: + if args.tokenizer and args.tokenizer.endswith(".json"): + if not args.tokenizer_config: + raise TypeError( + "Json tokenizers require an accompanying tokenizer config (--tokenizer_config) to be specified." + ) + + def build_args_parser() -> argparse.ArgumentParser: # TODO: merge these with build_args_parser from export_llama_lib. 
parser = argparse.ArgumentParser() @@ -85,6 +94,12 @@ def build_args_parser() -> argparse.ArgumentParser: default=None, ) + parser.add_argument( + "--tokenizer_config", + type=str, + default=None, + ) + parser.add_argument( "--prompt", type=str, @@ -116,6 +131,7 @@ def build_args_parser() -> argparse.ArgumentParser: def main() -> None: parser = build_args_parser() args = parser.parse_args() + validate_args(args) runner = NativeLlamaRunner(args) generated_tokens = runner.text_completion( prompt=args.prompt, diff --git a/examples/models/llama3_2_vision/runner/generation.py b/examples/models/llama3_2_vision/runner/generation.py index e17760fd852..88a6a44a535 100644 --- a/examples/models/llama3_2_vision/runner/generation.py +++ b/examples/models/llama3_2_vision/runner/generation.py @@ -13,6 +13,7 @@ class TorchTuneLlamaRunner(LlamaRunner): def __init__( self, + *, tokenizer_path: str, max_seq_len: int, max_batch_size: int, @@ -21,12 +22,12 @@ def __init__( device: str = "cpu", ): super().__init__( - tokenizer_path, - max_seq_len, - max_batch_size, - use_kv_cache, - vocab_size, - device, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + use_kv_cache=use_kv_cache, + vocab_size=vocab_size, + device=device, ) self.causal_mask = torch.tril( diff --git a/extension/llm/tokenizer/hf_tokenizer.py b/extension/llm/tokenizer/hf_tokenizer.py index 777952dd465..5d588a0caf8 100644 --- a/extension/llm/tokenizer/hf_tokenizer.py +++ b/extension/llm/tokenizer/hf_tokenizer.py @@ -3,6 +3,7 @@ import re from typing import Dict, List, Optional + class HFTokenizer: def __init__(self): self.special_token_encoder: Dict[str, int] = {} diff --git a/extension/llm/tokenizer/utils.py b/extension/llm/tokenizer/utils.py index a0502cb3212..745656de38e 100644 --- a/extension/llm/tokenizer/utils.py +++ b/extension/llm/tokenizer/utils.py @@ -4,29 +4,32 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import json +from typing import Optional + from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken +from executorch.extension.llm.tokenizer.hf_tokenizer import HFTokenizer from executorch.extension.llm.tokenizer.tokenizer import ( Tokenizer as SentencePieceTokenizer, ) -from executorch.extension.llm.tokenizer.hf_tokenizer import HFTokenizer -def get_tokenizer(tokenizer_path): +def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None): if tokenizer_path.endswith(".json"): - # print("Using Hugging Face tokenizer") - # tokenizer = HFTokenizer() - # tokenizer.load(tokenizer_path) - from tokenizers import Tokenizer # Load the tokenizer from the tokenizer.json file tokenizer = Tokenizer.from_file(tokenizer_path) - - # from tokenizers import SentencePieceBPETokenizer - # tokenizer = SentencePieceBPETokenizer(tokenizer_path) + # export_llama expects n_words attribute. tokenizer.n_words = tokenizer.get_vocab_size() - breakpoint() + # Keep in line with internal tokenizer apis. 
+ tokenizer.decode_token = lambda token: tokenizer.decode([token]) + + if tokenizer_config_path: + with open(tokenizer_config_path) as f: + tokenizer_config = json.load(f) + tokenizer.eos_id = tokenizer.token_to_id(tokenizer_config["eos_token"]) else: try: tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path)) From 5834d14bf22c498df8f3e1301b724b3f4f00d52f Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 11:20:17 -0800 Subject: [PATCH 4/8] Fix encode, remove generated python tokenizer --- examples/models/llama/runner/generation.py | 10 +- extension/llm/tokenizer/hf_tokenizer.py | 199 --------------------- extension/llm/tokenizer/utils.py | 3 +- 3 files changed, 6 insertions(+), 206 deletions(-) delete mode 100644 extension/llm/tokenizer/hf_tokenizer.py diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 23d55b5461a..3e9ceb34af5 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -165,8 +165,7 @@ def text_completion( This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. """ return self.generate( - # prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False), - prompt_tokens=self.tokenizer.encode(prompt).ids, + prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False), max_seq_len=self.max_seq_len, temperature=temperature, top_p=top_p, @@ -200,10 +199,9 @@ def chat_completion( prompt = input("Me: ") while prompt and prompt != exit_prompt: print("LLM: ", end="", flush=True) - # prompt_tokens = self.tokenizer.encode( - # self._format_prompt(prompt), bos=True, eos=False - # ) - prompt_tokens = self.tokenizer.encode(self._format_prompt(prompt)).ids + prompt_tokens = self.tokenizer.encode( + self._format_prompt(prompt), bos=True, eos=False + ) generated_tokens = self.generate( prompt_tokens=pre_stop_token + prompt_tokens, max_seq_len=max_seq_len, diff --git a/extension/llm/tokenizer/hf_tokenizer.py b/extension/llm/tokenizer/hf_tokenizer.py deleted file mode 100644 index 5d588a0caf8..00000000000 --- a/extension/llm/tokenizer/hf_tokenizer.py +++ /dev/null @@ -1,199 +0,0 @@ -import json -import os -import re -from typing import Dict, List, Optional - - -class HFTokenizer: - def __init__(self): - self.special_token_encoder: Dict[str, int] = {} - self.special_token_decoder: Dict[int, str] = {} - self.encoder: Dict[str, int] = {} - self.decoder: Dict[int, str] = {} - self.n_words: int = 0 - self.bos_id: Optional[int] = None - self.eos_id: Optional[int] = None - self.initialized: bool = False - self.pre_tokenizer_config = None - - def load(self, path: str) -> bool: - if os.path.isdir(path): - model_json = os.path.join(path, "tokenizer.json") - model_config_json = os.path.join(path, "tokenizer_config.json") - else: - model_json = path - model_config_json = "" - - if not os.path.exists(model_json): - print(f"no tokenizer.json found in {path}") - return False - - try: - with open(model_json, "r") as file: - parsed_json = json.load(file) - except json.JSONDecodeError as e: - print(f"Error parsing json file: {e}") - return False - - # Parse special tokens - try: - special_tokens = parsed_json["added_tokens"] - for token_info in special_tokens: - token = token_info["content"] - token_id = token_info["id"] - if token in self.special_token_encoder: - print(f"duplicate special token: {token}") - return False - if token_id in 
self.special_token_decoder: - print(f"duplicate special token id: {token_id}") - return False - self.special_token_encoder[token] = token_id - self.special_token_decoder[token_id] = token - except KeyError as e: - print(f"Could not parse special tokens: {e}") - return False - - # Parse standard tokens - try: - vocab = parsed_json["model"]["vocab"] - for token, token_id in vocab.items(): - if token_id not in self.special_token_decoder: - if token in self.encoder: - print(f"duplicate token: {token}") - return False - if token_id in self.decoder: - print(f"duplicate token id: {token_id}") - return False - self.encoder[token] = token_id - self.decoder[token_id] = token - except KeyError as e: - print(f"Could not parse tokens: {e}") - return False - - self.n_words = len(self.encoder) + len(self.special_token_encoder) - - # Parse tokenizer config if available - if model_config_json and os.path.exists(model_config_json): - try: - with open(model_config_json, "r") as file: - config_json = json.load(file) - bos_token = config_json["bos_token"] - eos_token = config_json["eos_token"] - if bos_token not in self.special_token_encoder: - print(f"BOS token {bos_token} not in special tokens") - return False - if eos_token not in self.special_token_encoder: - print(f"EOS token {eos_token} not in special tokens") - return False - self.bos_id = self.special_token_encoder[bos_token] - self.eos_id = self.special_token_encoder[eos_token] - except KeyError as e: - print(f"Could not parse eos/bos from tokenizer config: {e}") - return False - else: - # Guess BOS and EOS tokens - bos_candidates = [] - eos_candidates = [] - for token in self.special_token_encoder: - if "bos" in token or "begin" in token: - bos_candidates.append(token) - if "eos" in token or "end" in token: - eos_candidates.append(token) - if len(bos_candidates) == 1: - self.bos_id = self.special_token_encoder[bos_candidates[0]] - if len(eos_candidates) == 1: - self.eos_id = self.special_token_encoder[eos_candidates[0]] - if self.bos_id is not None and self.eos_id is None: - self.eos_id = self.bos_id - elif self.eos_id is not None and self.bos_id is None: - self.bos_id = self.eos_id - - # Parse pre-tokenizer configuration - try: - self.pre_tokenizer_config = parsed_json.get("pre_tokenizer", {}) - except KeyError as e: - print(f"Could not parse pre_tokenizer: {e}") - return False - - self.initialized = True - return True - - def encode(self, text: str, bos: bool = False, eos: bool = False) -> List[int]: - breakpoint() - if not self.initialized: - raise ValueError("Tokenizer not initialized") - tokens = [] - for piece in self._pretokenize(text): - if piece in self.encoder: - tokens.append(self.encoder[piece]) - else: - # Handle unknown tokens (e.g., byte pair encoding) - pass - if bos and self.bos_id is not None: - tokens = [self.bos_id] + tokens - if eos and self.eos_id is not None: - tokens.append(self.eos_id) - return tokens - - def decode(self, tokens: List[int]) -> str: - if not self.initialized: - raise ValueError("Tokenizer not initialized") - text = "" - for token in tokens: - if token in self.decoder: - text += self.decoder[token] - elif token in self.special_token_decoder: - text += self.special_token_decoder[token] - else: - # Handle unknown tokens - pass - return text - - def _pretokenize(self, text: str) -> List[str]: - if not self.pre_tokenizer_config: - return [text] # Default to no pre-tokenization - - breakpoint() - pre_tokenizer_type = self.pre_tokenizer_config.get("type", "") - if pre_tokenizer_type == "Split": - return 
self._split_pretokenize(text) - elif pre_tokenizer_type == "Digits": - return self._digits_pretokenize(text) - elif pre_tokenizer_type == "ByteLevel": - return self._byte_level_pretokenize(text) - elif pre_tokenizer_type == "Sequence": - return self._sequence_pretokenize(text) - else: - return [text] # Unsupported pre-tokenizer type - - def _split_pretokenize(self, text: str) -> List[str]: - pattern = self.pre_tokenizer_config.get("pattern", "") - if not pattern: - return [text] - return re.split(f"({pattern})", text) - - def _digits_pretokenize(self, text: str) -> List[str]: - individual_digits = self.pre_tokenizer_config.get("individual_digits", False) - if individual_digits: - return list(text) # Split into individual characters - else: - return re.split(r"(\d+)", text) # Split on digits - - def _byte_level_pretokenize(self, text: str) -> List[str]: - add_prefix_space = self.pre_tokenizer_config.get("add_prefix_space", False) - pattern = self.pre_tokenizer_config.get("pattern", "") - if add_prefix_space and not text.startswith(" "): - text = " " + text - if not pattern: - pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+" - return re.findall(pattern, text) - - def _sequence_pretokenize(self, text: str) -> List[str]: - pretokenizers = self.pre_tokenizer_config.get("pretokenizers", []) - pieces = [text] - for pretokenizer_config in pretokenizers: - new_pieces = [] - for piece in pieces: - new_pieces.extend(self._pretokenize(piece)) - pieces = new_pieces - return pieces diff --git a/extension/llm/tokenizer/utils.py b/extension/llm/tokenizer/utils.py index 745656de38e..07bfac31151 100644 --- a/extension/llm/tokenizer/utils.py +++ b/extension/llm/tokenizer/utils.py @@ -8,7 +8,6 @@ from typing import Optional from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken -from executorch.extension.llm.tokenizer.hf_tokenizer import HFTokenizer from executorch.extension.llm.tokenizer.tokenizer import ( Tokenizer as SentencePieceTokenizer, ) @@ -25,6 +24,8 @@ def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = No tokenizer.n_words = tokenizer.get_vocab_size() # Keep in line with internal tokenizer apis. tokenizer.decode_token = lambda token: tokenizer.decode([token]) + original_encode = tokenizer.encode + tokenizer.encode = lambda prompt, **kwargs: original_encode(prompt).ids if tokenizer_config_path: with open(tokenizer_config_path) as f: From 421288b1a83f1c98904f7b4e2778830df159be5d Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 11:22:54 -0800 Subject: [PATCH 5/8] Comment / lint --- extension/llm/tokenizer/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/extension/llm/tokenizer/utils.py b/extension/llm/tokenizer/utils.py index 07bfac31151..b6efa2e7bc2 100644 --- a/extension/llm/tokenizer/utils.py +++ b/extension/llm/tokenizer/utils.py @@ -17,12 +17,10 @@ def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = No if tokenizer_path.endswith(".json"): from tokenizers import Tokenizer - # Load the tokenizer from the tokenizer.json file tokenizer = Tokenizer.from_file(tokenizer_path) - # export_llama expects n_words attribute. - tokenizer.n_words = tokenizer.get_vocab_size() # Keep in line with internal tokenizer apis. 
+ tokenizer.n_words = tokenizer.get_vocab_size() tokenizer.decode_token = lambda token: tokenizer.decode([token]) original_encode = tokenizer.encode tokenizer.encode = lambda prompt, **kwargs: original_encode(prompt).ids From e3352fa8b94b21d09fb313b68f2b548c1d61ea07 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 11:59:35 -0800 Subject: [PATCH 6/8] Move into class --- examples/models/llama/runner/eager.py | 7 ++++ extension/llm/tokenizer/hf_tokenizer.py | 52 +++++++++++++++++++++++++ extension/llm/tokenizer/utils.py | 15 +------ 3 files changed, 61 insertions(+), 13 deletions(-) create mode 100644 extension/llm/tokenizer/hf_tokenizer.py diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index 54cfc283ae9..a812a59e37b 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -28,6 +28,7 @@ def __init__(self, args): params = json.loads(f.read()) super().__init__( tokenizer_path=args.tokenizer_path, + tokenizer_config_path=args.tokenizer_config_path max_seq_len=args.max_seq_length, max_batch_size=1, use_kv_cache=args.use_kv_cache, @@ -72,6 +73,12 @@ def build_args_parser() -> argparse.ArgumentParser: action="store_true", default=False, help="Have multi-turn chat with the model", + )p + + parser.add_argument( + "--tokenizer_config_path", + type=str, + deafult=None, ) return parser diff --git a/extension/llm/tokenizer/hf_tokenizer.py b/extension/llm/tokenizer/hf_tokenizer.py new file mode 100644 index 00000000000..8e6ab9461a8 --- /dev/null +++ b/extension/llm/tokenizer/hf_tokenizer.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +from typing import List, Optional + +from tokenizers import Tokenizer + + +class HuggingFaceTokenizer: + """ + Tokenizing and encoding/decoding text using the Hugging face tokenizer. + """ + def __init__(self, model_path: str, config_path: Optional[str] = None): + """ + Initializes the Tokenizer with a tokenizer.json from HuggingFace. + + Args: + model_path (str): The path to the Tiktoken model file. + """ + assert os.path.isfile(model_path), model_path + + self.model = tokenizer = Tokenizer.from_file(model_path) + + self.n_words: int = tokenizer.get_vocab_size() + if config_path: + with open(config_path) as f: + tokenizer_config = json.load(f) + self.bos_id = self.model.token_to_id(tokenizer_config["bos_token"])if tokenizer_config["bos_token"] else None + self.eos_id = self.model.token_to_id(tokenizer_config["eos_token"]) + else: # Fallback guess. 
+ self.bos_id = self.model.token_to_id("<|begin_of_text|>") + self.eos_id = self.model.token_to_id("<|endoftext|>") + + self.stop_tokens = [ + self.eos_id, + ] + + def encode(self, s: str, *, bos: bool, eos: bool) -> List[int]: + assert type(s) is str + return self.model.encode(s).ids + + def decode(self, t: List[int]) -> str: + return self.model.decode(t) + + def decode_token(self, t: int) -> str: + return self.model.decode([t]) + diff --git a/extension/llm/tokenizer/utils.py b/extension/llm/tokenizer/utils.py index b6efa2e7bc2..307372ae72f 100644 --- a/extension/llm/tokenizer/utils.py +++ b/extension/llm/tokenizer/utils.py @@ -15,20 +15,9 @@ def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None): if tokenizer_path.endswith(".json"): - from tokenizers import Tokenizer + from executorch.extension.llm.tokenizer.hf_tokenizer import HuggingFaceTokenizer - tokenizer = Tokenizer.from_file(tokenizer_path) - - # Keep in line with internal tokenizer apis. - tokenizer.n_words = tokenizer.get_vocab_size() - tokenizer.decode_token = lambda token: tokenizer.decode([token]) - original_encode = tokenizer.encode - tokenizer.encode = lambda prompt, **kwargs: original_encode(prompt).ids - - if tokenizer_config_path: - with open(tokenizer_config_path) as f: - tokenizer_config = json.load(f) - tokenizer.eos_id = tokenizer.token_to_id(tokenizer_config["eos_token"]) + tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path) else: try: tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path)) From 01ed13c397e83f89adb74d1d5fa61f7f9e9eed0d Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 13:50:18 -0800 Subject: [PATCH 7/8] Lint --- examples/models/llama/runner/eager.py | 6 +++--- extension/llm/tokenizer/hf_tokenizer.py | 8 ++++++-- extension/llm/tokenizer/utils.py | 1 - 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index a812a59e37b..fac0af891b3 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -28,7 +28,7 @@ def __init__(self, args): params = json.loads(f.read()) super().__init__( tokenizer_path=args.tokenizer_path, - tokenizer_config_path=args.tokenizer_config_path + tokenizer_config_path=args.tokenizer_config_path, max_seq_len=args.max_seq_length, max_batch_size=1, use_kv_cache=args.use_kv_cache, @@ -73,12 +73,12 @@ def build_args_parser() -> argparse.ArgumentParser: action="store_true", default=False, help="Have multi-turn chat with the model", - )p + ) parser.add_argument( "--tokenizer_config_path", type=str, - deafult=None, + default=None, ) return parser diff --git a/extension/llm/tokenizer/hf_tokenizer.py b/extension/llm/tokenizer/hf_tokenizer.py index 8e6ab9461a8..cc2e2cfcfe3 100644 --- a/extension/llm/tokenizer/hf_tokenizer.py +++ b/extension/llm/tokenizer/hf_tokenizer.py @@ -15,6 +15,7 @@ class HuggingFaceTokenizer: """ Tokenizing and encoding/decoding text using the Hugging face tokenizer. """ + def __init__(self, model_path: str, config_path: Optional[str] = None): """ Initializes the Tokenizer with a tokenizer.json from HuggingFace. 
@@ -30,7 +31,11 @@ def __init__(self, model_path: str, config_path: Optional[str] = None): if config_path: with open(config_path) as f: tokenizer_config = json.load(f) - self.bos_id = self.model.token_to_id(tokenizer_config["bos_token"])if tokenizer_config["bos_token"] else None + self.bos_id = ( + self.model.token_to_id(tokenizer_config["bos_token"]) + if tokenizer_config["bos_token"] + else None + ) self.eos_id = self.model.token_to_id(tokenizer_config["eos_token"]) else: # Fallback guess. self.bos_id = self.model.token_to_id("<|begin_of_text|>") @@ -49,4 +54,3 @@ def decode(self, t: List[int]) -> str: def decode_token(self, t: int) -> str: return self.model.decode([t]) - diff --git a/extension/llm/tokenizer/utils.py b/extension/llm/tokenizer/utils.py index 307372ae72f..5377048438a 100644 --- a/extension/llm/tokenizer/utils.py +++ b/extension/llm/tokenizer/utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import json from typing import Optional from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken From b0990dad99e866cf8f272b959ba1ed5466e8f229 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 15:00:22 -0800 Subject: [PATCH 8/8] Mengwei pr rev --- examples/models/llama/runner/eager.py | 1 + examples/models/llama/runner/native.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index fac0af891b3..0b842a8f976 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -79,6 +79,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--tokenizer_config_path", type=str, default=None, + help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json", ) return parser diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index d71815edba5..6d5d4730844 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -98,6 +98,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--tokenizer_config", type=str, default=None, + help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json", ) parser.add_argument(
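
A minimal usage sketch of the tokenizer path added by this patch series, assuming the series is applied and the tokenizers package from the updated install_requirements.sh is installed; the file paths and prompt string below are placeholders, not values taken from the patches.

    # Illustrative only: exercises the get_tokenizer() dispatch from extension/llm/tokenizer/utils.py.
    # A ".json" path selects HuggingFaceTokenizer; any other path falls back to SentencePiece and
    # then Tiktoken, as before this series.
    from executorch.extension.llm.tokenizer.utils import get_tokenizer

    tokenizer = get_tokenizer(
        "path/to/tokenizer.json",         # placeholder: a Hugging Face tokenizer.json
        "path/to/tokenizer_config.json",  # optional: supplies bos_token/eos_token metadata
    )

    # encode() takes keyword-only bos/eos flags and returns a plain List[int] (the .ids of the
    # HF encoding); the HuggingFaceTokenizer in this series does not add bos/eos itself.
    ids = tokenizer.encode("Hello world", bos=True, eos=False)
    print(tokenizer.n_words, tokenizer.bos_id, tokenizer.eos_id)
    print(tokenizer.decode(ids))
    print(tokenizer.decode_token(ids[0]))  # per-token decode used by the runners while streaming

On the command line, the native runner takes the same pair of files; roughly (flag names inferred from the args accessed in native.py, other required flags omitted):

    python examples/models/llama/runner/native.py \
        --tokenizer path/to/tokenizer.json \
        --tokenizer_config path/to/tokenizer_config.json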