@@ -48,7 +48,9 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 class LlamaRunner(ABC):
     def __init__(
         self,
+        *,
         tokenizer_path: str,
+        tokenizer_config_path: Optional[str] = None,
         max_seq_len: int,
         max_batch_size: int,
         use_kv_cache: bool,
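The bare `*` makes every constructor argument keyword-only, so call sites must name each argument. A minimal call-site sketch under that assumption; the `NativeLlamaRunner` subclass name and the argument values are hypothetical, not taken from this diff:

# Hypothetical concrete subclass of LlamaRunner; values are placeholders.
runner = NativeLlamaRunner(
    tokenizer_path="/models/llama/tokenizer.model",
    tokenizer_config_path=None,  # new optional argument introduced by this change
    max_seq_len=2048,
    max_batch_size=1,
    use_kv_cache=True,
    vocab_size=32000,
    device="cpu",
)
# A positional call such as NativeLlamaRunner("/models/llama/tokenizer.model", 2048, ...)
# now raises TypeError because of the bare `*` in the parameter list.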
@@ -59,20 +61,23 @@ def __init__(
         Constructor.

         Args:
-            tokenizer_path: path to tokenizer.model file.
-            max_seq_len: max length of the output sequence, after which the output will be clipped.
-            max_batch_size: max batch size.
-            use_kv_cache: whether to use a KV cache.
-            vocab_size: number of items in the vocab.
-            device: device to run the runner on.
+            tokenizer_path: path to tokenizer.model file.
+            max_seq_len: max length of the output sequence, after which the output will be clipped.
+            max_batch_size: max batch size.
+            use_kv_cache: whether to use a KV cache.
+            vocab_size: number of items in the vocab.
+            device: device to run the runner on.
         """
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.use_kv_cache = use_kv_cache
-        self.tokenizer = get_tokenizer(tokenizer_path)
+        self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path)
         self.device = device
-        # For qwen anything above 151646 is "useless": https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706
-        # assert vocab_size == self.tokenizer.n_words
+        # For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706
+        if vocab_size != self.tokenizer.n_words:
+            print(
+                "Warning - given vocab_size in params is unequal to tokenizer vocab size."
+            )

     @abstractmethod
     def forward(
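The new `tokenizer_config_path` is passed straight through to `get_tokenizer`, whose implementation is not part of this hunk. A rough sketch of what such a helper might do, assuming a SentencePiece `tokenizer.model` by default and a Hugging Face `tokenizers` JSON file when a config path is supplied; all of the branching below is an assumption, not code from this PR:

from typing import Optional

# Assumed sketch; the real get_tokenizer in the repo may differ.
def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None):
    if tokenizer_config_path is not None:
        # Hypothetical branch: a Hugging Face `tokenizers` tokenizer plus its config file.
        from tokenizers import Tokenizer
        return Tokenizer.from_file(tokenizer_path)
    # Default branch: a SentencePiece tokenizer.model file.
    import sentencepiece as spm
    return spm.SentencePieceProcessor(model_file=tokenizer_path)
    # The real helper presumably wraps both in a common interface exposing
    # n_words, encode(...), decode(...), and decode_token(...), which the runner uses.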
@@ -102,8 +107,7 @@ def generate( # noqa: C901
             )

         current_token = next_token(logits, temperature, top_p)
-        # print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
-        print(f"{self.tokenizer.decode([current_token])}", end="", flush=True)
+        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

         while len(tokens) < max_seq_len:
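Both streaming prints now go through `decode_token`, which takes a single id instead of wrapping it in a list for `decode`. A minimal sketch of what such a method could look like on a tokenizer wrapper; the wrapper class and its internals are assumptions, not part of this diff:

class TokenizerWrapper:
    """Hypothetical wrapper exposing the interface the runner relies on."""

    def __init__(self, hf_tokenizer):
        self._tok = hf_tokenizer  # e.g. a tokenizers.Tokenizer instance

    def decode(self, token_ids):
        # Decode a whole sequence of ids into text.
        return self._tok.decode(token_ids)

    def decode_token(self, token_id: int) -> str:
        # Decode exactly one id; equivalent to decode([token_id]) but makes the
        # single-token streaming use case explicit at the call site.
        return self._tok.decode([token_id])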
@@ -133,8 +137,7 @@ def generate( # noqa: C901
             ):
                 break

-            # print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
-            print(f"{self.tokenizer.decode([current_token])}", end="", flush=True)
+            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         print("\n")

         return tokens if echo else tokens[len(prompt_tokens) :]
@@ -200,9 +203,7 @@ def chat_completion(
             # prompt_tokens = self.tokenizer.encode(
             #     self._format_prompt(prompt), bos=True, eos=False
             # )
-            prompt_tokens = self.tokenizer.encode(
-                self._format_prompt(prompt)
-            ).ids
+            prompt_tokens = self.tokenizer.encode(self._format_prompt(prompt)).ids
             generated_tokens = self.generate(
                 prompt_tokens=pre_stop_token + prompt_tokens,
                 max_seq_len=max_seq_len,
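The collapsed call still ends in `.ids` because a Hugging Face `tokenizers` encode returns an `Encoding` object rather than a plain list of ids. A small standalone illustration of that behavior; the tokenizer file path is a placeholder, not from this PR:

from tokenizers import Tokenizer

# Placeholder path; substitute the tokenizer.json that matches your model.
tok = Tokenizer.from_file("/models/llama/tokenizer.json")

encoding = tok.encode("Hello, world!")
print(type(encoding))   # tokenizers.Encoding
print(encoding.ids)     # the list[int] the runner passes on to generate()
print(encoding.tokens)  # the corresponding token strings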