@@ -159,11 +159,14 @@ class llama_token_data_array(Structure):
159159
160160
161161# struct llama_context_params {
162- # uint32_t seed; // RNG seed, -1 for random
163- # int32_t n_ctx; // text context
164- # int32_t n_batch; // prompt processing batch size
165- # int32_t n_gpu_layers; // number of layers to store in VRAM
166- # int32_t main_gpu; // the GPU that is used for scratch and small tensors
162+ # uint32_t seed; // RNG seed, -1 for random
163+ # int32_t n_ctx; // text context
164+ # int32_t n_batch; // prompt processing batch size
165+ # int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
166+ # float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
167+ # int32_t n_gpu_layers; // number of layers to store in VRAM
168+ # int32_t main_gpu; // the GPU that is used for scratch and small tensors
169+ #
167170# const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
168171
169172# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -190,6 +193,8 @@ class llama_context_params(Structure):
190193 ("seed" , c_uint32 ),
191194 ("n_ctx" , c_int32 ),
192195 ("n_batch" , c_int32 ),
196+ ("n_gqa" , c_int32 ),
197+ ("rms_norm_eps" , c_float ),
193198 ("n_gpu_layers" , c_int32 ),
194199 ("main_gpu" , c_int32 ),
195200 ("tensor_split" , POINTER (c_float )),
@@ -265,6 +270,57 @@ class llama_model_quantize_params(Structure):
265270 ]
266271
267272
# // grammar types
# struct llama_grammar;
# Opaque pointer to a C-side llama_grammar object; contents are never
# inspected from Python, so a plain void pointer suffices.
llama_grammar_p = c_void_p

# // grammar element type
# enum llama_gretype {
#     LLAMA_GRETYPE_END            = 0, // end of rule definition
#     LLAMA_GRETYPE_ALT            = 1, // start of alternate definition for rule
#     LLAMA_GRETYPE_RULE_REF       = 2, // non-terminal element: reference to rule
#     LLAMA_GRETYPE_CHAR           = 3, // terminal element: character (code point)
#     LLAMA_GRETYPE_CHAR_NOT       = 4, // inverse char(s) ([^a], [^a-b] [^abc])
#     LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, // modifies a preceding LLAMA_GRETYPE_CHAR or
#                                       // LLAMA_GRETYPE_CHAR_ALT to be an inclusive
#                                       // range ([a-z])
#     LLAMA_GRETYPE_CHAR_ALT       = 6, // modifies a preceding LLAMA_GRETYPE_CHAR or
#                                       // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an
#                                       // alternate char to match ([ab], [a-zA])
# };
# Mirrored here as c_int values so they can be passed straight through ctypes.
LLAMA_GRETYPE_END = c_int(0)
LLAMA_GRETYPE_ALT = c_int(1)
LLAMA_GRETYPE_RULE_REF = c_int(2)
LLAMA_GRETYPE_CHAR = c_int(3)
LLAMA_GRETYPE_CHAR_NOT = c_int(4)
LLAMA_GRETYPE_CHAR_RNG_UPPER = c_int(5)
LLAMA_GRETYPE_CHAR_ALT = c_int(6)
309+
310+
# typedef struct llama_grammar_element {
#     enum llama_gretype type;
#     uint32_t value; // Unicode code point or rule ID
# } llama_grammar_element;
class llama_grammar_element(Structure):
    """ctypes mirror of the C llama_grammar_element struct.

    ``type`` holds a llama_gretype enum value; ``value`` is either a
    Unicode code point or a rule ID, depending on ``type``.
    """

    _fields_ = [
        ("type", c_int),
        ("value", c_uint32),
    ]


# Pointer type used for rule arrays handed to the C API.
llama_grammar_element_p = POINTER(llama_grammar_element)
323+
268324# // performance timing information
269325# struct llama_timings {
270326# double t_start_ms;
@@ -871,6 +927,37 @@ def llama_token_nl() -> int:
871927_lib .llama_token_nl .restype = llama_token
872928
873929
# // Grammar
# //
# LLAMA_API struct llama_grammar * llama_grammar_init(
#         const llama_grammar_element ** rules,
#         size_t n_rules,
#         size_t start_rule_index);
def llama_grammar_init(
    rules,  # type: Array[llama_grammar_element_p] # type: ignore
    n_rules: c_size_t,
    start_rule_index: c_size_t,
) -> llama_grammar_p:
    """Create a grammar object from an array of rule pointers.

    ``rules`` is an array of pointers to rule-element sequences, ``n_rules``
    its length, and ``start_rule_index`` selects the root rule. Returns an
    opaque ``llama_grammar_p`` owned by the caller (free with
    ``llama_grammar_free``).
    """
    return _lib.llama_grammar_init(rules, n_rules, start_rule_index)


_lib.llama_grammar_init.argtypes = [
    POINTER(llama_grammar_element_p),
    c_size_t,
    c_size_t,
]
_lib.llama_grammar_init.restype = llama_grammar_p
950+
951+
# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
def llama_grammar_free(grammar: llama_grammar_p):
    """Release a grammar previously created by ``llama_grammar_init``.

    The handle must not be used after this call.
    """
    return _lib.llama_grammar_free(grammar)


_lib.llama_grammar_free.argtypes = [llama_grammar_p]
_lib.llama_grammar_free.restype = None
959+
960+
874961# Sampling functions
875962
876963
0 commit comments