@@ -227,7 +227,6 @@ def __init__(
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
-        grammar: Optional[Union[str, Path]] = None,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
         rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
         mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
@@ -254,7 +253,6 @@ def __init__(
             tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             rope_freq_base: Base frequency for rope sampling.
             rope_freq_scale: Scale factor for rope sampling.
-            grammar: Path to a BNF grammar file to use for grammar based sampling.
             verbose: Print verbose output to stderr.

         Raises:
@@ -383,12 +381,6 @@ def __init__(
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx, self._n_vocab), dtype=np.single
         )
-        if grammar is not None:
-            self.grammar = LlamaGrammar.from_file(
-                grammar, verbose=verbose
-            )  # type: Optional[LlamaGrammar]
-        else:
-            self.grammar = None

     @property
     def _input_ids(self) -> npt.NDArray[np.intc]:
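With this change the grammar is no longer fixed at construction time: `Llama.__init__` drops the `grammar` argument and the `self.grammar` attribute, and callers instead pass a `LlamaGrammar` object per call. A minimal before/after sketch (the `.gbnf` and model paths are illustrative):

```python
from llama_cpp import Llama, LlamaGrammar

# Previously the grammar was bound to the model instance:
#   llm = Llama(model_path="model.bin", grammar="json.gbnf")

# Now: load the grammar once, pass it per request.
grammar = LlamaGrammar.from_file("json.gbnf")  # illustrative path
llm = Llama(model_path="model.bin")
result = llm("Return a JSON object describing a cat:", grammar=grammar)
```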
@@ -527,6 +519,7 @@ def _sample(
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         assert self.ctx is not None
         assert self.n_tokens > 0
@@ -574,11 +567,11 @@ def _sample(
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)

-        if self.grammar is not None:
+        if grammar is not None:
             llama_cpp.llama_sample_grammar(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
             )

         if temp.value == 0.0:
@@ -650,10 +643,10 @@ def _sample(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
-            if self.grammar is not None:
+            if grammar is not None:
                 llama_cpp.llama_grammar_accept_token(
                     ctx=self.ctx,
-                    grammar=self.grammar.grammar,
+                    grammar=grammar.grammar,
                     token=llama_cpp.ctypes.c_int(id),
                 )
             return id
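The two grammar calls in `_sample` form a constrain-then-accept pair: `llama_sample_grammar` filters out candidates that cannot extend a valid parse, and once a token is committed, `llama_grammar_accept_token` advances the grammar's parser state. A sketch of that order, using only the calls visible in the hunk above; `pick_token` is a hypothetical stand-in for the temperature/top-k/mirostat branches:

```python
import llama_cpp

def constrained_pick(self, candidates, grammar, pick_token):
    # 1) Constrain: zero out candidates that cannot continue a valid parse.
    if grammar is not None:
        llama_cpp.llama_sample_grammar(
            ctx=self.ctx,
            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            grammar=grammar.grammar,
        )
    id = pick_token(candidates)  # any of the sampling branches in _sample
    # 2) Accept: advance the grammar's parse state past the chosen token.
    if grammar is not None:
        llama_cpp.llama_grammar_accept_token(
            ctx=self.ctx,
            grammar=grammar.grammar,
            token=llama_cpp.ctypes.c_int(id),
        )
    return id
```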
@@ -672,6 +665,7 @@ def sample(
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         """Sample a token from the model.

@@ -705,6 +699,7 @@ def sample(
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
+            grammar=grammar,
         )

     def generate(
@@ -723,6 +718,7 @@ def generate(
         mirostat_eta: float = 0.1,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.

@@ -761,8 +757,8 @@ def generate(
         if reset:
             self.reset()

-        if self.grammar is not None:
-            self.grammar.reset()
+        if grammar is not None:
+            grammar.reset()

         while True:
             self.eval(tokens)
@@ -778,6 +774,7 @@ def generate(
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
                 logits_processor=logits_processor,
+                grammar=grammar,
             )
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids.tolist(), self._scores[-1, :].tolist()
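`generate` now resets the caller's grammar at the start of a run (alongside the model's own `reset`), so a single `LlamaGrammar` instance can be reused across prompts without carrying parser state over. A sketch, reusing the `llm` and `grammar` objects from the earlier example and capping output length rather than checking for EOS:

```python
from itertools import islice

for prompt in (b"First: ", b"Second: "):
    tokens = llm.tokenize(prompt)
    # grammar.reset() happens inside generate(), so reuse across runs is safe.
    for token in islice(llm.generate(tokens, grammar=grammar), 32):
        print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="")
```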
@@ -880,6 +877,7 @@ def _create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None

@@ -957,6 +955,7 @@ def _create_completion(
             repeat_penalty=repeat_penalty,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -1301,6 +1300,7 @@ def create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1345,6 +1345,7 @@ def create_completion(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1374,6 +1375,7 @@ def __call__(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1418,6 +1420,7 @@ def __call__(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         )

     def _convert_text_completion_to_chat(
@@ -1498,6 +1501,7 @@ def create_chat_completion(
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.

@@ -1540,6 +1544,7 @@ def create_chat_completion(
             mirostat_eta=mirostat_eta,
             model=model,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
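Threaded through the whole call chain (`create_chat_completion` → `create_completion` → `_create_completion` → `generate` → `sample` → `_sample`), the new parameter lets one loaded model serve differently constrained requests. A usage sketch; both `.gbnf` files are hypothetical:

```python
json_grammar = LlamaGrammar.from_file("json.gbnf")
list_grammar = LlamaGrammar.from_file("list.gbnf")

chat = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Describe a cat as JSON."}],
    grammar=json_grammar,
)
completion = llm.create_completion(
    "Three prime numbers:\n", grammar=list_grammar
)
```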