11import os
2+ from pathlib import Path
23import sys
34import uuid
45import time
2324
2425from . import llama_cpp
2526from .llama_types import *
27+ from .llama_grammar import LlamaGrammar
2628
2729import numpy as np
2830import numpy .typing as npt
@@ -223,6 +225,7 @@ def __init__(
223225 tensor_split : Optional [List [float ]] = None ,
224226 rope_freq_base : float = 10000.0 ,
225227 rope_freq_scale : float = 1.0 ,
228+ grammar : Optional [Union [str , Path ]] = None ,
226229 n_gqa : Optional [int ] = None , # (TEMPORARY) must be 8 for llama2 70b
227230 rms_norm_eps : Optional [float ] = None , # (TEMPORARY)
228231 verbose : bool = True ,
@@ -248,6 +251,7 @@ def __init__(
248251 tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
249252 rope_freq_base: Base frequency for rope sampling.
250253 rope_freq_scale: Scale factor for rope sampling.
254+ grammar: Path to a BNF grammar file to use for grammar based sampling.
251255 verbose: Print verbose output to stderr.
252256
253257 Raises:
@@ -358,6 +362,12 @@ def __init__(
358362 self .scores : npt .NDArray [np .single ] = np .ndarray (
359363 (n_ctx , self ._n_vocab ), dtype = np .single
360364 )
365+ if grammar is not None :
366+ self .grammar = LlamaGrammar .from_file (
367+ grammar
368+ ) # type: Optional[LlamaGrammar]
369+ else :
370+ self .grammar = None
361371
362372 @property
363373 def _input_ids (self ) -> npt .NDArray [np .intc ]:
@@ -542,8 +552,16 @@ def _sample(
542552 )
543553 if not penalize_nl :
544554 candidates .data [self ._token_nl ].logit = llama_cpp .c_float (nl_logit )
555+
556+ if self .grammar is not None :
557+ llama_cpp .llama_sample_grammar (
558+ ctx = self .ctx ,
559+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
560+ grammar = self .grammar .grammar ,
561+ )
562+
545563 if temp .value == 0.0 :
546- return llama_cpp .llama_sample_token_greedy (
564+ id = llama_cpp .llama_sample_token_greedy (
547565 ctx = self .ctx ,
548566 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
549567 )
@@ -555,7 +573,7 @@ def _sample(
555573 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
556574 temp = temp ,
557575 )
558- return llama_cpp .llama_sample_token_mirostat (
576+ id = llama_cpp .llama_sample_token_mirostat (
559577 ctx = self .ctx ,
560578 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
561579 tau = mirostat_tau ,
@@ -570,7 +588,7 @@ def _sample(
570588 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
571589 temp = temp ,
572590 )
573- return llama_cpp .llama_sample_token_mirostat_v2 (
591+ id = llama_cpp .llama_sample_token_mirostat_v2 (
574592 ctx = self .ctx ,
575593 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
576594 tau = mirostat_tau ,
@@ -607,10 +625,17 @@ def _sample(
607625 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
608626 temp = temp ,
609627 )
610- return llama_cpp .llama_sample_token (
628+ id = llama_cpp .llama_sample_token (
611629 ctx = self .ctx ,
612630 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
613631 )
632+ if self .grammar is not None :
633+ llama_cpp .llama_grammar_accept_token (
634+ ctx = self .ctx ,
635+ grammar = self .grammar .grammar ,
636+ token = llama_cpp .ctypes .c_int (id ),
637+ )
638+ return id
614639
615640 def sample (
616641 self ,
@@ -1509,6 +1534,9 @@ def __del__(self):
15091534 if self .ctx is not None :
15101535 llama_cpp .llama_free (self .ctx )
15111536 self .ctx = None
1537+ if self .grammar is not None :
1538+ llama_cpp .llama_grammar_free (self .grammar .grammar )
1539+ self .grammar = None
15121540
15131541 def __getstate__ (self ):
15141542 return dict (
0 commit comments