@@ -159,11 +159,13 @@ class llama_token_data_array(Structure):
159159
160160
161161# struct llama_context_params {
162- # uint32_t seed; // RNG seed, -1 for random
163- # int32_t n_ctx; // text context
164- # int32_t n_batch; // prompt processing batch size
165- # int32_t n_gpu_layers; // number of layers to store in VRAM
166- # int32_t main_gpu; // the GPU that is used for scratch and small tensors
162+ # uint32_t seed; // RNG seed, -1 for random
163+ # int32_t n_ctx; // text context
164+ # int32_t n_batch; // prompt processing batch size
165+ # int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
166+ # int32_t n_gpu_layers; // number of layers to store in VRAM
167+ # int32_t main_gpu; // the GPU that is used for scratch and small tensors
168+ #
167169# const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
168170
169171# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -190,6 +192,7 @@ class llama_context_params(Structure):
190192 ("seed" , c_uint32 ),
191193 ("n_ctx" , c_int32 ),
192194 ("n_batch" , c_int32 ),
195+ ("n_gqa" , c_int32 ),
193196 ("n_gpu_layers" , c_int32 ),
194197 ("main_gpu" , c_int32 ),
195198 ("tensor_split" , POINTER (c_float )),
@@ -265,6 +268,57 @@ class llama_model_quantize_params(Structure):
265268 ]
266269
267270
# // grammar types
# struct llama_grammar;
# Opaque handle to llama.cpp's grammar object; only ever created by
# llama_grammar_init() and passed back to the C API (freed via llama_grammar_free).
llama_grammar_p = c_void_p

# // grammar element type
# enum llama_gretype {
#     // end of rule definition
#     LLAMA_GRETYPE_END            = 0,
#
#     // start of alternate definition for rule
#     LLAMA_GRETYPE_ALT            = 1,
#
#     // non-terminal element: reference to rule
#     LLAMA_GRETYPE_RULE_REF       = 2,
#
#     // terminal element: character (code point)
#     LLAMA_GRETYPE_CHAR           = 3,
#
#     // inverse char(s) ([^a], [^a-b] [^abc])
#     LLAMA_GRETYPE_CHAR_NOT       = 4,
#
#     // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
#     // be an inclusive range ([a-z])
#     LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
#
#     // modifies a preceding LLAMA_GRETYPE_CHAR or
#     // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
#     LLAMA_GRETYPE_CHAR_ALT       = 6,
# };
# Python-side mirrors of the C `llama_gretype` enum above.
# NOTE(review): these are ctypes c_int objects, not plain ints — comparing one
# against a Python int (e.g. `element.type == LLAMA_GRETYPE_CHAR`) is always
# False; compare against `.value` instead. Left as c_int here because changing
# them to plain ints would alter the values existing callers receive.
LLAMA_GRETYPE_END = c_int(0)
LLAMA_GRETYPE_ALT = c_int(1)
LLAMA_GRETYPE_RULE_REF = c_int(2)
LLAMA_GRETYPE_CHAR = c_int(3)
LLAMA_GRETYPE_CHAR_NOT = c_int(4)
LLAMA_GRETYPE_CHAR_RNG_UPPER = c_int(5)
LLAMA_GRETYPE_CHAR_ALT = c_int(6)
308+
# typedef struct llama_grammar_element {
#     enum llama_gretype type;
#     uint32_t value; // Unicode code point or rule ID
# } llama_grammar_element;
class llama_grammar_element(Structure):
    """ctypes mirror of llama.cpp's `llama_grammar_element` struct.

    A grammar rule is an array of these elements; the `type` field selects
    how `value` is interpreted (see the llama_gretype comment above the
    LLAMA_GRETYPE_* constants).
    """

    _fields_ = [
        ("type", c_int),      # a llama_gretype discriminant
        ("value", c_uint32),  # Unicode code point or rule ID, depending on type
    ]


# Pointer-to-element type; llama_grammar_init takes an array of these
# (one pointer per rule).
llama_grammar_element_p = POINTER(llama_grammar_element)
321+
268322# // performance timing information
269323# struct llama_timings {
270324# double t_start_ms;
@@ -871,6 +925,37 @@ def llama_token_nl() -> int:
871925_lib .llama_token_nl .restype = llama_token
872926
873927
# // Grammar
# //
# LLAMA_API struct llama_grammar * llama_grammar_init(
#             const llama_grammar_element ** rules,
#                                  size_t    n_rules,
#                                  size_t    start_rule_index);
def llama_grammar_init(
    rules,  # type: Array[llama_grammar_element_p] # type: ignore
    n_rules: c_size_t,
    start_rule_index: c_size_t,
) -> llama_grammar_p:
    """Build a llama.cpp grammar from an array of rule definitions.

    `rules` is an array of `n_rules` pointers, each to a
    llama_grammar_element array forming one rule; `start_rule_index`
    selects the root rule. Returns an opaque grammar handle that must be
    released with llama_grammar_free().
    """
    grammar = _lib.llama_grammar_init(rules, n_rules, start_rule_index)
    return grammar


_lib.llama_grammar_init.argtypes = [POINTER(llama_grammar_element_p), c_size_t, c_size_t]
_lib.llama_grammar_init.restype = llama_grammar_p
948+
949+
# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
def llama_grammar_free(grammar: llama_grammar_p):
    """Release a grammar previously created by llama_grammar_init().

    The handle must not be used after this call. Returns None (the C
    function is void).
    """
    return _lib.llama_grammar_free(grammar)


_lib.llama_grammar_free.argtypes = [llama_grammar_p]
_lib.llama_grammar_free.restype = None
957+
958+
874959# Sampling functions
875960
876961
0 commit comments