 
 from . import llama_cpp
 from .llama_types import *
+from .llama_grammar import LlamaGrammar
 
 import numpy as np
 import numpy.typing as npt
 
+from .utils import suppress_stdout_stderr
 
 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""
@@ -231,7 +233,8 @@ def __init__(
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
-        rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
+        rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
+        mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -241,6 +244,7 @@ def __init__(
             n_ctx: Maximum context size.
             n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined.
             seed: Random seed. -1 for random.
+            n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
             f16_kv: Use half-precision for key/value cache.
             logits_all: Return logits for all tokens, not just the last token.
             vocab_only: Only load the vocabulary, no weights.
@@ -269,7 +273,7 @@ def __init__(
 
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        self.params.n_gpu_layers = n_gpu_layers
+        self.params.n_gpu_layers = 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv
         self.params.logits_all = logits_all
@@ -280,7 +284,7 @@ def __init__(
         self.params.low_vram = low_vram
 
         self.tensor_split = tensor_split
-        self._c_tensor_split = None
+        self._p_tensor_split = None
 
         if self.tensor_split is not None:
             # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
@@ -299,6 +303,9 @@ def __init__(
         if rms_norm_eps is not None:
             self.params.rms_norm_eps = rms_norm_eps
 
+        if mul_mat_q is not None:
+            self.params.mul_mat_q = mul_mat_q
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
@@ -316,12 +323,25 @@ def __init__(
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")
 
-        self.model = llama_cpp.llama_load_model_from_file(
-            self.model_path.encode("utf-8"), self.params
-        )
+        if verbose:
+            self.model = llama_cpp.llama_load_model_from_file(
+                self.model_path.encode("utf-8"), self.params
+            )
+        else:
+            with suppress_stdout_stderr():
+                self.model = llama_cpp.llama_load_model_from_file(
+                    self.model_path.encode("utf-8"), self.params
+                )
         assert self.model is not None
 
-        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+        if verbose:
+            self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+        else:
+            with suppress_stdout_stderr():
+                print("here")
+                self.ctx = llama_cpp.llama_new_context_with_model(
+                    self.model, self.params
+                )
 
         assert self.ctx is not None
 
@@ -358,8 +378,8 @@ def __init__(
             sorted=sorted,
         )
         self._candidates = candidates
-        self._token_nl = Llama.token_nl()
-        self._token_eos = Llama.token_eos()
+        self._token_nl = self.token_nl()
+        self._token_eos = self.token_eos()
         self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc)  # type: ignore
         self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)
 
@@ -437,10 +457,14 @@ def detokenize(self, tokens: List[int]) -> bytes:
         """
         assert self.ctx is not None
         output = b""
+        buffer_size = 32
+        buffer = (ctypes.c_char * buffer_size)()
         for token in tokens:
-            output += llama_cpp.llama_token_to_str(
-                self.ctx, llama_cpp.llama_token(token)
+            n = llama_cpp.llama_token_to_str(
+                self.ctx, llama_cpp.llama_token(token), buffer, buffer_size
             )
+            assert n <= buffer_size
+            output += bytes(buffer[:n])
         return output
 
     def set_cache(self, cache: Optional[BaseLlamaCache]):
@@ -506,6 +530,7 @@ def _sample(
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         assert self.ctx is not None
         assert self.n_tokens > 0
@@ -548,8 +573,16 @@ def _sample(
         )
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
+
+        if grammar is not None:
+            llama_cpp.llama_sample_grammar(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                grammar=grammar.grammar,
+            )
+
         if temp.value == 0.0:
-            return llama_cpp.llama_sample_token_greedy(
+            id = llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
@@ -561,7 +594,7 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token_mirostat(
+            id = llama_cpp.llama_sample_token_mirostat(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 tau=mirostat_tau,
@@ -576,7 +609,7 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token_mirostat_v2(
+            id = llama_cpp.llama_sample_token_mirostat_v2(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 tau=mirostat_tau,
@@ -613,10 +646,17 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token(
+            id = llama_cpp.llama_sample_token(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
+        if grammar is not None:
+            llama_cpp.llama_grammar_accept_token(
+                ctx=self.ctx,
+                grammar=grammar.grammar,
+                token=llama_cpp.ctypes.c_int(id),
+            )
+        return id
 
     def sample(
         self,
@@ -632,6 +672,7 @@ def sample(
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         """Sample a token from the model.
 
@@ -665,6 +706,7 @@ def sample(
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
 
     def generate(
@@ -683,6 +725,7 @@ def generate(
         mirostat_eta: float = 0.1,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -704,7 +747,6 @@ def generate(
             The generated tokens.
         """
         assert self.ctx is not None
-
         if reset and len(self._input_ids) > 0:
             longest_prefix = 0
             for a, b in zip(self._input_ids, tokens[:-1]):
@@ -722,6 +764,9 @@ def generate(
         if reset:
             self.reset()
 
+        if grammar is not None:
+            grammar.reset()
+
         while True:
             self.eval(tokens)
             token = self.sample(
@@ -736,6 +781,7 @@ def generate(
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
                 logits_processor=logits_processor,
+                grammar=grammar,
             )
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids, self._scores[-1, :]
@@ -838,6 +884,7 @@ def _create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 
@@ -915,6 +962,7 @@ def _create_completion(
             repeat_penalty=repeat_penalty,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -965,9 +1013,7 @@ def _create_completion(
                 for token in remaining_tokens:
                     token_end_position += len(self.detokenize([token]))
                     # Check if stop sequence is in the token
-                    if token_end_position >= (
-                        remaining_length - first_stop_position
-                    ):
+                    if token_end_position >= (remaining_length - first_stop_position):
                         break
                     logprobs_or_none: Optional[CompletionLogprobs] = None
                     if logprobs is not None:
@@ -1261,6 +1307,7 @@ def create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1305,6 +1352,7 @@ def create_completion(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1334,6 +1382,7 @@ def __call__(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1378,6 +1427,7 @@ def __call__(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
 
     def _convert_text_completion_to_chat(
@@ -1460,6 +1510,7 @@ def create_chat_completion(
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
@@ -1502,6 +1553,7 @@ def create_chat_completion(
             mirostat_eta=mirostat_eta,
             model=model,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
@@ -1511,10 +1563,10 @@
         return self._convert_text_completion_to_chat(completion)
 
     def __del__(self):
-        if self.model is not None:
+        if hasattr(self, "model") and self.model is not None:
             llama_cpp.llama_free_model(self.model)
             self.model = None
-        if self.ctx is not None:
+        if hasattr(self, "ctx") and self.ctx is not None:
             llama_cpp.llama_free(self.ctx)
             self.ctx = None
 
@@ -1638,20 +1690,20 @@ def tokenizer(self) -> "LlamaTokenizer":
         assert self.ctx is not None
         return LlamaTokenizer(self)
 
-    @staticmethod
-    def token_eos() -> int:
+    def token_eos(self) -> int:
         """Return the end-of-sequence token."""
-        return llama_cpp.llama_token_eos()
+        assert self.ctx is not None
+        return llama_cpp.llama_token_eos(self.ctx)
 
-    @staticmethod
-    def token_bos() -> int:
+    def token_bos(self) -> int:
         """Return the beginning-of-sequence token."""
-        return llama_cpp.llama_token_bos()
+        assert self.ctx is not None
+        return llama_cpp.llama_token_bos(self.ctx)
 
-    @staticmethod
-    def token_nl() -> int:
+    def token_nl(self) -> int:
         """Return the newline token."""
-        return llama_cpp.llama_token_nl()
+        assert self.ctx is not None
+        return llama_cpp.llama_token_nl(self.ctx)
 
     @staticmethod
     def logits_to_logprobs(logits: List[float]) -> List[float]:
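
Usage sketch (not part of the commit): a minimal example of how the `grammar` parameter wired through above might be called, assuming `LlamaGrammar.from_string` parses a GBNF grammar string; the model path and grammar text are placeholders.

    from llama_cpp import Llama, LlamaGrammar

    # Hypothetical GBNF grammar that only allows the answers "yes" or "no".
    grammar_text = 'root ::= "yes" | "no"'

    llm = Llama(model_path="./models/llama-2-7b.ggmlv3.q4_0.bin")  # placeholder path
    grammar = LlamaGrammar.from_string(grammar_text)  # assumed constructor

    # llama_sample_grammar filters candidate tokens before sampling and
    # llama_grammar_accept_token advances the grammar state after each token
    # (see _sample above), so the completion is constrained to the grammar.
    output = llm("Answer yes or no: is water wet?", grammar=grammar, max_tokens=8)
    print(output["choices"][0]["text"])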