diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh
index f72d0ea3cc2..cca6ede1d79 100755
--- a/examples/models/llama/install_requirements.sh
+++ b/examples/models/llama/install_requirements.sh
@@ -5,14 +5,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Install sentencepiece for the llama tokenizer.
+# Install tiktoken for the llama3 tokenizer.
+# Install tokenizers for the HF .json tokenizer.
 # Install snakeviz for cProfile flamegraph
-# Install sentencepiece for llama tokenizer
-pip install snakeviz sentencepiece
-
-# Install lm-eval for Model Evaluation with lm-evalution-harness
-# Install tiktoken for tokenizer
-pip install lm_eval==0.4.5
-pip install tiktoken blobfile
+# Install lm-eval for model evaluation with lm-evaluation-harness.
+pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
 
 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py
diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index 54cfc283ae9..0b842a8f976 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -28,6 +28,7 @@ def __init__(self, args):
             params = json.loads(f.read())
         super().__init__(
             tokenizer_path=args.tokenizer_path,
+            tokenizer_config_path=args.tokenizer_config_path,
             max_seq_len=args.max_seq_length,
             max_batch_size=1,
             use_kv_cache=args.use_kv_cache,
@@ -74,6 +75,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Have multi-turn chat with the model",
     )
 
+    parser.add_argument(
+        "--tokenizer_config_path",
+        type=str,
+        default=None,
+        help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json",
+    )
+
     return parser
 
 
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 891ce20db3e..3e9ceb34af5 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -48,7 +48,9 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 class LlamaRunner(ABC):
     def __init__(
         self,
+        *,
         tokenizer_path: str,
+        tokenizer_config_path: Optional[str] = None,
         max_seq_len: int,
         max_batch_size: int,
         use_kv_cache: bool,
@@ -59,19 +61,24 @@ def __init__(
         Constructor.
 
         Args:
-        tokenizer_path: path to tokenizer.model file.
-        max_seq_len: max length of the output sequence, after which the output will be clipped.
-        max_batch_size: max batch size.
-        use_kv_cache: whether to use a KV cache.
-        vocab_size: number of items in the vocab.
-        device: device to run the runner on.
+            tokenizer_path: path to the tokenizer file (SentencePiece model, tiktoken model, or HF tokenizer.json).
+            tokenizer_config_path: optional path to an accompanying tokenizer_config.json (HF tokenizer.json only).
+            max_seq_len: max length of the output sequence, after which the output will be clipped.
+            max_batch_size: max batch size.
+            use_kv_cache: whether to use a KV cache.
+            vocab_size: number of items in the vocab.
+            device: device to run the runner on.
""" self.max_seq_len = max_seq_len self.max_batch_size = max_batch_size self.use_kv_cache = use_kv_cache - self.tokenizer = get_tokenizer(tokenizer_path) + self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path) self.device = device - assert vocab_size == self.tokenizer.n_words + # For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706 + if vocab_size != self.tokenizer.n_words: + print( + "Warning - given vocab_size in params is unequal to tokenizer vocab size." + ) @abstractmethod def forward( diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index a6b055ced95..6d5d4730844 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -37,6 +37,7 @@ def __init__(self, args): params = json.loads(f.read()) super().__init__( tokenizer_path=args.tokenizer, + tokenizer_config_path=args.tokenizer_config, max_seq_len=args.max_len, max_batch_size=1, use_kv_cache=args.kv_cache, @@ -56,6 +57,14 @@ def forward( )[0] +def validate_args(args) -> None: + if args.tokenizer and args.tokenizer.endswith(".json"): + if not args.tokenizer_config: + raise TypeError( + "Json tokenizers require an accompanying tokenizer config (--tokenizer_config) to be specified." + ) + + def build_args_parser() -> argparse.ArgumentParser: # TODO: merge these with build_args_parser from export_llama_lib. parser = argparse.ArgumentParser() @@ -85,6 +94,13 @@ def build_args_parser() -> argparse.ArgumentParser: default=None, ) + parser.add_argument( + "--tokenizer_config", + type=str, + default=None, + help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json", + ) + parser.add_argument( "--prompt", type=str, @@ -116,6 +132,7 @@ def build_args_parser() -> argparse.ArgumentParser: def main() -> None: parser = build_args_parser() args = parser.parse_args() + validate_args(args) runner = NativeLlamaRunner(args) generated_tokens = runner.text_completion( prompt=args.prompt, diff --git a/examples/models/llama3_2_vision/runner/generation.py b/examples/models/llama3_2_vision/runner/generation.py index e17760fd852..88a6a44a535 100644 --- a/examples/models/llama3_2_vision/runner/generation.py +++ b/examples/models/llama3_2_vision/runner/generation.py @@ -13,6 +13,7 @@ class TorchTuneLlamaRunner(LlamaRunner): def __init__( self, + *, tokenizer_path: str, max_seq_len: int, max_batch_size: int, @@ -21,12 +22,12 @@ def __init__( device: str = "cpu", ): super().__init__( - tokenizer_path, - max_seq_len, - max_batch_size, - use_kv_cache, - vocab_size, - device, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + use_kv_cache=use_kv_cache, + vocab_size=vocab_size, + device=device, ) self.causal_mask = torch.tril( diff --git a/extension/llm/tokenizer/hf_tokenizer.py b/extension/llm/tokenizer/hf_tokenizer.py new file mode 100644 index 00000000000..cc2e2cfcfe3 --- /dev/null +++ b/extension/llm/tokenizer/hf_tokenizer.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +from typing import List, Optional + +from tokenizers import Tokenizer + + +class HuggingFaceTokenizer: + """ + Tokenizing and encoding/decoding text using the Hugging face tokenizer. 
+    """
+
+    def __init__(self, model_path: str, config_path: Optional[str] = None):
+        """
+        Initializes the Tokenizer with a tokenizer.json from HuggingFace.
+
+        Args:
+            model_path (str): The path to the tokenizer.json file.
+            config_path (Optional[str]): The path to an optional
+                tokenizer_config.json, used to resolve the bos/eos tokens.
+        """
+        assert os.path.isfile(model_path), model_path
+
+        self.model = Tokenizer.from_file(model_path)
+
+        self.n_words: int = self.model.get_vocab_size()
+        if config_path:
+            with open(config_path) as f:
+                tokenizer_config = json.load(f)
+            # bos_token/eos_token may be a plain string or an AddedToken dict
+            # with a "content" field, depending on how the config was exported.
+            bos_token = tokenizer_config.get("bos_token")
+            if isinstance(bos_token, dict):
+                bos_token = bos_token.get("content")
+            eos_token = tokenizer_config.get("eos_token")
+            if isinstance(eos_token, dict):
+                eos_token = eos_token.get("content")
+            self.bos_id = self.model.token_to_id(bos_token) if bos_token else None
+            self.eos_id = self.model.token_to_id(eos_token) if eos_token else None
+        else:  # Fallback guess.
+            self.bos_id = self.model.token_to_id("<|begin_of_text|>")
+            self.eos_id = self.model.token_to_id("<|endoftext|>")
+
+        self.stop_tokens = [
+            self.eos_id,
+        ]
+
+    def encode(self, s: str, *, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        token_ids = self.model.encode(s).ids
+        # The tokenizer's post-processor may already insert special tokens;
+        # only add bos/eos when requested and not already present.
+        if bos and self.bos_id is not None and token_ids[:1] != [self.bos_id]:
+            token_ids = [self.bos_id] + token_ids
+        if eos and self.eos_id is not None and token_ids[-1:] != [self.eos_id]:
+            token_ids.append(self.eos_id)
+        return token_ids
+
+    def decode(self, t: List[int]) -> str:
+        return self.model.decode(t)
+
+    def decode_token(self, t: int) -> str:
+        return self.model.decode([t])
diff --git a/extension/llm/tokenizer/utils.py b/extension/llm/tokenizer/utils.py
index 126a1203274..5377048438a 100644
--- a/extension/llm/tokenizer/utils.py
+++ b/extension/llm/tokenizer/utils.py
@@ -4,16 +4,24 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
 
 
-def get_tokenizer(tokenizer_path):
-    try:
-        tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
-    except Exception:
-        print("Using Tiktokenizer")
-        tokenizer = Tiktoken(model_path=str(tokenizer_path))
+def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None):
+    if tokenizer_path.endswith(".json"):
+        # Imported lazily to avoid a hard dependency on the `tokenizers` package.
+        from executorch.extension.llm.tokenizer.hf_tokenizer import HuggingFaceTokenizer
+
+        tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
+    else:
+        try:
+            tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
+        except Exception:
+            print("Using Tiktoken tokenizer.")
+            tokenizer = Tiktoken(model_path=str(tokenizer_path))
     return tokenizer
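
A minimal end-to-end sketch of the new get_tokenizer dispatch. It assumes the `tokenizers` package installed by the updated install_requirements.sh; the checkpoint paths are hypothetical placeholders, not files shipped with this change.

# Illustrative only; the paths below are placeholders.
from executorch.extension.llm.tokenizer.utils import get_tokenizer

# A .json tokenizer path selects HuggingFaceTokenizer; the accompanying
# tokenizer_config.json supplies the bos/eos token strings.
tokenizer = get_tokenizer(
    "downloads/qwen2_5/tokenizer.json",
    tokenizer_config_path="downloads/qwen2_5/tokenizer_config.json",
)
token_ids = tokenizer.encode("Hello, world!", bos=True, eos=False)
print(tokenizer.n_words, token_ids)
print(tokenizer.decode(token_ids))

# Any other path keeps the previous behavior: try SentencePiece first,
# then fall back to Tiktoken.
tokenizer = get_tokenizer("downloads/llama3/tokenizer.model")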