11"""C++ implementation of the llama grammar parser."""
22# flake8: noqa
3- import argparse
43from pathlib import Path
54import sys
65from ctypes import * # type: ignore
1918 overload ,
2019)
2120
22- import llama_cpp
21+ from . import llama_cpp
2322
2423# Type aliases
2524llama_grammar_element = llama_cpp .llama_grammar_element
@@ -41,11 +40,19 @@ class Sentinel:
4140class LlamaGrammar :
4241 """Keeps reference counts of all the arguments, so that they are not
4342 garbage collected by Python."""
43+
44+ def __del__ (self ) -> None :
45+ """Free the grammar pointer when the object is deleted."""
46+ if self .grammar is not None :
47+ llama_cpp .llama_grammar_free (self .grammar )
48+ self .grammar = None
4449
4550 def __init__ (
4651 self ,
4752 parsed_grammar : "parse_state" ,
4853 ) -> None :
54+ """Initialize the grammar pointer from the parsed state."""
55+ self .parsed_grammar = parsed_grammar
4956 grammar_rules = (
5057 parsed_grammar .c_rules ()
5158 ) # type: std.vector[std.vector[llama_grammar_element]]
@@ -69,22 +76,25 @@ def __init__(
6976
7077 self .n_rules = c_size_t (grammar_rules .size ())
7178 self .start_rule_index = c_size_t (parsed_grammar .symbol_ids .at ("root" ))
72- self .grammar = self .init_grammar ()
79+ self ._grammar = llama_cpp .llama_grammar_init (
80+ self .rules , self .n_rules , self .start_rule_index
81+ )
7382
7483 @classmethod
75- def from_string (cls , grammar : str ) -> "LlamaGrammar" :
84+ def from_string (cls , grammar : str , verbose : bool = True ) -> "LlamaGrammar" :
7685 parsed_grammar = parse (const_char_p (grammar )) # type: parse_state
7786 if parsed_grammar .rules .empty ():
7887 raise ValueError (
7988 f"{ cls .from_string .__name__ } : error parsing grammar file: parsed_grammar.rules is empty"
8089 )
81- print (f"{ cls .from_string .__name__ } grammar:" , file = sys .stderr )
82- print_grammar (sys .stdout , parsed_grammar )
83- print (file = sys .stderr )
90+ if verbose :
91+ print (f"{ cls .from_string .__name__ } grammar:" , file = sys .stderr )
92+ print_grammar (sys .stdout , parsed_grammar )
93+ print (file = sys .stderr )
8494 return cls (parsed_grammar )
8595
8696 @classmethod
87- def from_file (cls , file : Union [str , Path ]) -> "LlamaGrammar" :
97+ def from_file (cls , file : Union [str , Path ], verbose : bool = True ) -> "LlamaGrammar" :
8898 try :
8999 with open (file ) as f :
90100 grammar = f .read ()
@@ -94,14 +104,27 @@ def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar":
94104 )
95105
96106 if grammar :
97- return cls .from_string (grammar )
107+ return cls .from_string (grammar , verbose = verbose )
98108
99109 raise ValueError (
100110 f"{ cls .from_file .__name__ } : error parsing grammar file: params_grammer is empty"
101111 )
102112
103- def init_grammar (self ) -> llama_grammar_p :
104- return llama_cpp .llama_grammar_init (
113+ @property
114+ def grammar (self ) -> llama_grammar_p :
115+ if self ._grammar is None :
116+ raise ValueError (
117+ f"{ self .__class__ .__name__ } .grammar: grammar is freed"
118+ )
119+ return self ._grammar
120+
121+ @grammar .setter
122+ def grammar (self , value : Optional [llama_grammar_p ]) -> None :
123+ self ._grammar = value
124+
125+ def reset (self ) -> None :
126+ llama_cpp .llama_grammar_free (self .grammar )
127+ self .grammar = llama_cpp .llama_grammar_init (
105128 self .rules , self .n_rules , self .start_rule_index
106129 )
107130
@@ -1216,82 +1239,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
12161239 print (
12171240 f"{ print_grammar .__name__ } : error printing grammar: { err } " ,
12181241 file = sys .stderr ,
1219- )
1220-
1221-
1222- # def convert_to_rules(
1223- # llama_grammar_elements: std.vector[std.vector[llama_grammar_element]],
1224- # ) -> Array[llama_grammar_element_p]:
1225- # """Make an Array object that is used for `llama_grammer_init`"""
1226-
1227- # # Step 1: Convert each list to llama_grammar_element array and get pointer
1228- # element_arrays = [
1229- # (llama_grammar_element * len(subvector))(*subvector)
1230- # for subvector in llama_grammar_elements
1231- # ] # type: List[Array[llama_grammar_element]]
1232-
1233- # # Step 2: Get pointer of each array
1234- # element_array_pointers = [
1235- # cast(subarray, llama_grammar_element_p) for subarray in element_arrays
1236- # ] # type: List[llama_grammar_element_p]
1237-
1238- # # Step 3: Make array of these pointers and get its pointer
1239- # return (llama_grammar_element_p * len(element_array_pointers))(
1240- # *element_array_pointers
1241- # )
1242-
1243-
1244- if __name__ == "__main__" :
1245- parser = argparse .ArgumentParser (
1246- description = "Generate C++ parser from GBNF grammar"
1247- )
1248- parser .add_argument (
1249- "-g" ,
1250- "--grammar" ,
1251- type = str ,
1252- default = "./vendor/llama.cpp/grammars/json.gbnf" ,
1253- help = "path to GBNF grammar file" ,
1254- )
1255-
1256- args = parser .parse_args ()
1257- llama_grammar = LlamaGrammar .from_file (Path (args .grammar ))
1258- llama_grammar_ptr = llama_grammar .init_grammar ()
1259-
1260- # ----- USAGE:
1261- # llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p)
1262- # llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...)
1263-
1264- # ----- SAMPLE OUTPUT:
1265- # main grammar:
1266- # root ::= object
1267- # object ::= [{] ws object_11 [}] ws
1268- # value ::= object | array | string | number | value_6 ws
1269- # array ::= [[] ws array_15 []] ws
1270- # string ::= ["] string_18 ["] ws
1271- # number ::= number_19 number_25 number_29 ws
1272- # value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l]
1273- # ws ::= ws_31
1274- # object_8 ::= string [:] ws value object_10
1275- # object_9 ::= [,] ws string [:] ws value
1276- # object_10 ::= object_9 object_10 |
1277- # object_11 ::= object_8 |
1278- # array_12 ::= value array_14
1279- # array_13 ::= [,] ws value
1280- # array_14 ::= array_13 array_14 |
1281- # array_15 ::= array_12 |
1282- # string_16 ::= [^"\] | [\] string_17
1283- # string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]
1284- # string_18 ::= string_16 string_18 |
1285- # number_19 ::= number_20 number_21
1286- # number_20 ::= [-] |
1287- # number_21 ::= [0-9] | [1-9] number_22
1288- # number_22 ::= [0-9] number_22 |
1289- # number_23 ::= [.] number_24
1290- # number_24 ::= [0-9] number_24 | [0-9]
1291- # number_25 ::= number_23 |
1292- # number_26 ::= [eE] number_27 number_28
1293- # number_27 ::= [-+] |
1294- # number_28 ::= [0-9] number_28 | [0-9]
1295- # number_29 ::= number_26 |
1296- # ws_30 ::= [ <U+0009><U+000A>] ws
1297- # ws_31 ::= ws_30 |
1242+ )
0 commit comments