@@ -228,7 +228,7 @@ def __init__(
         rope_freq_scale: float = 1.0,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
         rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
-        mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
+        mul_mat_q: Optional[bool] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -290,11 +290,6 @@ def __init__(
         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale
 
-        if n_gqa is not None:
-            self.params.n_gqa = n_gqa
-
-        if rms_norm_eps is not None:
-            self.params.rms_norm_eps = rms_norm_eps
 
         if mul_mat_q is not None:
             self.params.mul_mat_q = mul_mat_q
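
With this change the grouped-query-attention and RMS-norm overrides disappear from the constructor entirely; presumably the converted model files now carry that metadata themselves, leaving `mul_mat_q` as the only surviving optional override. A minimal sketch of a post-change constructor call (the model path is hypothetical):

```python
from llama_cpp import Llama

# n_gqa / rms_norm_eps are no longer accepted as keyword arguments;
# the model file is expected to supply those values itself.
llm = Llama(
    model_path="./models/llama-2-7b.gguf",  # hypothetical path
    mul_mat_q=True,  # still optional; forwarded to llama.cpp only when set
)
```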
@@ -371,8 +366,8 @@ def __init__(
             sorted=sorted,
         )
         self._candidates = candidates
-        self._token_nl = Llama.token_nl()
-        self._token_eos = Llama.token_eos()
+        self._token_nl = self.token_nl()
+        self._token_eos = self.token_eos()
         self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc)  # type: ignore
         self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)
 
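
The cached `_token_nl`/`_token_eos` values now come from the instance rather than from static `Llama` methods, since the special-token ids are a property of the loaded vocabulary; see the sketch after the `token_eos`/`token_bos`/`token_nl` hunk near the end of this diff.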
@@ -413,11 +408,11 @@ def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         Returns:
             A list of tokens.
         """
-        assert self.ctx is not None
+        assert self.model is not None
         n_ctx = self._n_ctx
         tokens = (llama_cpp.llama_token * n_ctx)()
-        n_tokens = llama_cpp.llama_tokenize(
-            self.ctx,
+        n_tokens = llama_cpp.llama_tokenize_with_model(
+            self.model,
             text,
             tokens,
             llama_cpp.c_int(n_ctx),
@@ -426,8 +421,8 @@ def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         if n_tokens < 0:
             n_tokens = abs(n_tokens)
             tokens = (llama_cpp.llama_token * n_tokens)()
-            n_tokens = llama_cpp.llama_tokenize(
-                self.ctx,
+            n_tokens = llama_cpp.llama_tokenize_with_model(
+                self.model,
                 text,
                 tokens,
                 llama_cpp.c_int(n_tokens),
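
Both call sites now use `llama_tokenize_with_model`, which takes the model handle instead of a full context. A negative return value means the token buffer was too small, and its absolute value is the size actually needed; the second call above retries with a buffer of exactly that size. A usage sketch, assuming `llm` is an already-loaded `Llama` instance (hypothetical):

```python
# Hedged sketch: the actual ids depend entirely on the loaded vocabulary.
tokens = llm.tokenize(b"Hello, world!")  # add_bos=True by default, so tokens[0] is BOS
print(len(tokens), tokens[0] == llm.token_bos())
```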
@@ -448,13 +443,19 @@ def detokenize(self, tokens: List[int]) -> bytes:
         Returns:
             The detokenized string.
         """
-        assert self.ctx is not None
+        assert self.model is not None
         output = b""
+        size = 8
+        buffer = (ctypes.c_char * size)()
         for token in tokens:
-            output += llama_cpp.llama_token_to_str(
-                self.ctx, llama_cpp.llama_token(token)
+            n = llama_cpp.llama_token_to_str_with_model(
+                self.model, llama_cpp.llama_token(token), buffer, size
             )
-        return output
+            assert n <= size
+            output += bytes(buffer[:n])
+        # NOTE: Llama1 models automatically added a space at the start of the prompt
+        # this line removes a leading space if the first token is a beginning of sentence token
+        return output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output
 
     def set_cache(self, cache: Optional[BaseLlamaCache]):
         """Set the cache.
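
`llama_token_to_str_with_model` fills a caller-supplied buffer and returns how many bytes it wrote, so the loop reuses one small 8-byte buffer per token (the `assert` guards the assumption that no single token piece exceeds it). The final line undoes the leading space that LLaMA-style tokenizers emit right after a BOS token. A sketch of the observable effect, assuming a loaded `llm` (hypothetical):

```python
# Hypothetical round trip; exact token ids vary by model.
tokens = llm.tokenize(b"hi")   # [bos_id, ...] since add_bos=True by default
text = llm.detokenize(tokens)  # leading space stripped because tokens[0] == llm.token_bos()
print(text)                    # b"hi"
```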
@@ -885,7 +886,7 @@ def _create_completion(
         created: int = int(time.time())
         completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[int] = self.tokenize(b" " + prompt.encode("utf-8"))
+        prompt_tokens: List[int] = self.tokenize(prompt.encode("utf-8")) if prompt != "" else [self.token_bos()]
         text: bytes = b""
         returned_tokens: int = 0
         stop = (
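
The completion path no longer prepends a literal space byte before tokenizing; the tokenizer itself is trusted to reproduce the original LLaMA behaviour, and an empty prompt now degenerates to a lone BOS token so generation still has a starting state. (The `# Add blank space...` comment above is a leftover from the old behaviour.) A sketch of both branches, assuming a loaded `llm` (hypothetical):

```python
for prompt in ("", "Once upon a time"):
    toks = llm.tokenize(prompt.encode("utf-8")) if prompt != "" else [llm.token_bos()]
    assert len(toks) >= 1  # generation always starts from at least one token
```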
@@ -1581,13 +1582,7 @@ def __getstate__(self):
             lora_base=self.lora_base,
             lora_path=self.lora_path,
             tensor_split=self.tensor_split,
-            ### TEMPORARY ###
-            n_gqa=self.params.n_gqa,
-            rms_norm_eps=self.params.rms_norm_eps,
-            ### TEMPORARY ###
-            ### DEPRECATED ###
-            n_parts=self.n_parts,
-            ### DEPRECATED ###
+            mul_mat_q=self.params.mul_mat_q,
         )
 
     def __setstate__(self, state):
@@ -1609,14 +1604,8 @@ def __setstate__(self, state):
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
             tensor_split=state["tensor_split"],
+            mul_mat_q=state["mul_mat_q"],
             verbose=state["verbose"],
-            ### TEMPORARY ###
-            n_gqa=state["n_gqa"],
-            rms_norm_eps=state["rms_norm_eps"],
-            ### TEMPORARY ###
-            ### DEPRECATED ###
-            n_parts=state["n_parts"],
-            ### DEPRECATED ###
         )
 
     def save_state(self) -> LlamaState:
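
With the temporary and deprecated fields gone, the pickled state is just the surviving constructor arguments plus `mul_mat_q`, and `__setstate__` rebuilds the object by passing them back to `__init__`. A hedged round-trip sketch (model path hypothetical):

```python
import pickle

llm = Llama(model_path="./models/llama-2-7b.gguf")  # hypothetical path
llm2 = pickle.loads(pickle.dumps(llm))  # reloads the model via __setstate__ -> __init__
```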
@@ -1681,20 +1670,20 @@ def tokenizer(self) -> "LlamaTokenizer":
         assert self.ctx is not None
         return LlamaTokenizer(self)
 
-    @staticmethod
-    def token_eos() -> int:
+    def token_eos(self) -> int:
         """Return the end-of-sequence token."""
-        return llama_cpp.llama_token_eos()
+        assert self.ctx is not None
+        return llama_cpp.llama_token_eos(self.ctx)
 
-    @staticmethod
-    def token_bos() -> int:
+    def token_bos(self) -> int:
         """Return the beginning-of-sequence token."""
-        return llama_cpp.llama_token_bos()
+        assert self.ctx is not None
+        return llama_cpp.llama_token_bos(self.ctx)
 
-    @staticmethod
-    def token_nl() -> int:
+    def token_nl(self) -> int:
         """Return the newline token."""
-        return llama_cpp.llama_token_nl()
+        assert self.ctx is not None
+        return llama_cpp.llama_token_nl(self.ctx)
 
     @staticmethod
     def logits_to_logprobs(logits: List[float]) -> List[float]:
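
The special-token helpers become instance methods because the ids are now read from the loaded model through its context; calling them on the class as before (`Llama.token_eos()`) would fail with a `TypeError` for the missing `self`. Usage sketch, assuming a loaded `llm` (hypothetical):

```python
eos, bos, nl = llm.token_eos(), llm.token_bos(), llm.token_nl()
print(eos, bos, nl)  # actual ids depend on the model's vocabulary
```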