@@ -1,9 +1,13 @@
 from __future__ import annotations
 
-import os
 import ctypes
+import os
 import pathlib
 
+from ._ggml import (
+    ggml_opt_get_optimizer_params,
+)
+
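Note: this import assumes a companion `_ggml` module that exposes `ggml_opt_get_optimizer_params` as a ctypes callback type. A minimal sketch of what that module might define, assuming the struct layout in `ggml-opt.h` (the flattened field list below is an assumption, not verified against this PR):

    # _ggml.py (hypothetical sketch)
    import ctypes

    class ggml_opt_optimizer_params(ctypes.Structure):
        # Assumed AdamW hyperparameters, mirroring ggml-opt.h
        _fields_ = [
            ("alpha", ctypes.c_float),  # learning rate
            ("beta1", ctypes.c_float),
            ("beta2", ctypes.c_float),
            ("eps", ctypes.c_float),
            ("wd", ctypes.c_float),     # weight decay
        ]

    # typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
    ggml_opt_get_optimizer_params = ctypes.CFUNCTYPE(
        ggml_opt_optimizer_params, ctypes.c_void_p
    )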
 from typing import (
     Callable,
     Union,
@@ -171,6 +175,10 @@
 # llama_sampler_p = NewType("llama_sampler_p", int)
 # llama_sampler_p_ctypes = ctypes.c_void_p
 
+# struct llama_opt_params;
+llama_opt_params_p = NewType("llama_opt_params_p", int)
+llama_opt_params_p_ctypes = ctypes.c_void_p
+
 # struct llama_kv_cache;
 llama_kv_cache_p = NewType("llama_kv_cache_p", int)
 llama_kv_cache_p_ctypes = ctypes.c_void_p
@@ -243,6 +251,7 @@
 # LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
 # LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
 # LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
+# LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
 # };
 LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0
 LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1
@@ -279,6 +288,7 @@
 LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32
 LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33
 LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34
+LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35
 
 
 # // note: these values should be synchronized with ggml_rope
@@ -790,6 +800,7 @@ class llama_model_params(ctypes.Structure):
 # bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
 # bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
 # bool no_perf;     // whether to measure performance timings
+# bool op_offload;  // whether to offload host tensor operations to device
 # };
 class llama_context_params(ctypes.Structure):
     """Parameters for llama_context
@@ -811,7 +822,7 @@ class llama_context_params(ctypes.Structure):
         yarn_beta_fast (float): YaRN low correction dim
         yarn_beta_slow (float): YaRN high correction dim
         yarn_orig_ctx (int): YaRN original context size
-        defrag_thold (float): defragment the KV cache if holes/size > thold, < 0 disabled (default)
+        defrag_thold (float): defragment the KV cache if holes/size > thold, <= 0 disabled (default)
         cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval
         cb_eval_user_data (ctypes.c_void_p): user data for cb_eval
         type_k (int): data type for K cache
@@ -822,6 +833,7 @@ class llama_context_params(ctypes.Structure):
         offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
         flash_attn (bool): whether to use flash attention
         no_perf (bool): whether to measure performance timings
+        op_offload (bool): whether to offload host tensor operations to device
     """
 
     if TYPE_CHECKING:
@@ -852,6 +864,7 @@ class llama_context_params(ctypes.Structure):
         offload_kqv: bool
         flash_attn: bool
         no_perf: bool
+        op_offload: bool
 
     _fields_ = [
         ("n_ctx", ctypes.c_uint32),
@@ -881,6 +894,7 @@ class llama_context_params(ctypes.Structure):
881894 ("offload_kqv" , ctypes .c_bool ),
882895 ("flash_attn" , ctypes .c_bool ),
883896 ("no_perf" , ctypes .c_bool ),
897+ ("op_offload" , ctypes .c_bool ),
884898 ]
885899
886900
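The new flag is reachable from Python once default parameters are obtained; a minimal sketch, assuming this module's usual low-level entry points:

    import llama_cpp

    params = llama_cpp.llama_context_default_params()
    params.op_offload = False  # keep host tensor ops off the device
    # ...then pass `params` to context creation as usual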
@@ -1193,7 +1207,20 @@ def llama_model_load_from_splits(
     ...
 
 
-# LLAMA_API void llama_free_model(struct llama_model * model);
+# LLAMA_API void llama_model_save_to_file(
+#            const struct llama_model * model,
+#                          const char * path_model);
+@ctypes_function(
+    "llama_model_save_to_file",
+    [llama_model_p_ctypes, ctypes.c_char_p],
+    None,
+)
+def llama_model_save_to_file(model: llama_model_p, path_model: bytes, /):
+    ...
+
+
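A short usage sketch for the new export (assuming a model already loaded through this module's low-level API; the output path is illustrative):

    # Save a loaded model back out as a GGUF file.
    llama_cpp.llama_model_save_to_file(model, b"./model-copy.gguf")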
+# DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
+#            "use llama_model_free instead");
 @ctypes_function(
     "llama_free_model",
     [llama_model_p_ctypes],
@@ -4128,8 +4155,8 @@ def llama_sampler_get_seed(smpl: llama_sampler_p, /) -> int:
     llama_token,
 )
 def llama_sampler_sample(
-    smpl: llama_sampler_p, ctx: llama_context_p, idx: int, /
-) -> int:
+    smpl: llama_sampler_p, ctx: llama_context_p, idx: ctypes.c_int32, /
+) -> ctypes.c_int32:
     ...
 
 
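Note that even with the `ctypes.c_int32` annotations, ctypes converts a `c_int32` result to a plain Python int, so call sites keep passing and receiving ints; a sketch:

    # idx = -1 samples from the logits of the last token in the batch
    token = llama_cpp.llama_sampler_sample(smpl, ctx, -1)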
@@ -4306,3 +4333,85 @@ def llama_perf_sampler_reset(chain: llama_sampler_p, /):
     ...
 
 
+# //
+# // training
+# //
+
+# // function that returns whether or not a given tensor contains trainable parameters
+# typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
+llama_opt_param_filter = ctypes.CFUNCTYPE(
+    ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p
+)
+
+
+# // always returns true
+# LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
+@ctypes_function(
+    "llama_opt_param_filter_all",
+    [ctypes.c_void_p, ctypes.c_void_p],
+    ctypes.c_bool,
+)
+def llama_opt_param_filter_all(
+    tensor: ctypes.c_void_p,  # ggml_tensor *, untyped here, hence c_void_p
+    userdata: ctypes.c_void_p, /
+) -> bool:
+    ...
+
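Custom filters can be created from the CFUNCTYPE above; a minimal sketch (keep a reference to the callback object so it is not garbage-collected while the C side may still call it):

    @llama_opt_param_filter
    def my_param_filter(tensor, userdata):
        # Illustrative: mark every tensor trainable, like llama_opt_param_filter_all.
        return True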
+# struct llama_opt_params {
+#     uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
+
+#     llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
+#     void * param_filter_ud;              // userdata for determining which tensors contain trainable parameters
+
+#     ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+#     void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+# };
+class llama_opt_params(ctypes.Structure):
+    _fields_ = [
+        ("n_ctx_train", ctypes.c_uint32),
+        ("param_filter", llama_opt_param_filter),
+        ("param_filter_ud", ctypes.c_void_p),
+        ("get_opt_pars", ggml_opt_get_optimizer_params),
+        ("get_opt_pars_ud", ctypes.c_void_p),
+    ]
+
+
+# LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
+# NOTE: lopt_params is passed by value in the C API, so the struct type itself
+# (not a pointer) must be used as the argtype.
+@ctypes_function(
+    "llama_opt_init",
+    [llama_context_p_ctypes, llama_model_p_ctypes, llama_opt_params],
+    None,
+)
+def llama_opt_init(
+    lctx: llama_context_p,
+    model: llama_model_p,
+    lopt_params: llama_opt_params, /
+):
+    ...
+
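Putting the struct and initializer together, a hedged sketch (`get_opt_pars_cb` stands in for a valid `ggml_opt_get_optimizer_params` callback, which this diff does not construct; passing NULL there would crash the C side):

    param_filter = llama_opt_param_filter(lambda tensor, userdata: True)
    opt_params = llama_opt_params(
        n_ctx_train=0,                 # 0: use the context size from lctx
        param_filter=param_filter,     # keep a Python reference alive
        param_filter_ud=None,
        get_opt_pars=get_opt_pars_cb,  # assumed: valid optimizer-params callback
        get_opt_pars_ud=None,
    )
    llama_opt_init(lctx, model, opt_params)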
+# LLAMA_API void llama_opt_epoch(
+#         struct llama_context * lctx,
+#         ggml_opt_dataset_t dataset,
+#         ggml_opt_result_t result_train,
+#         ggml_opt_result_t result_eval,
+#         int64_t idata_split,
+#         ggml_opt_epoch_callback callback_train,
+#         ggml_opt_epoch_callback callback_eval);
+@ctypes_function(
+    "llama_opt_epoch",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_int64,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+    ],
+    None,
+)
+def llama_opt_epoch(
+    lctx: llama_context_p,
+    dataset: ctypes.c_void_p,
+    result_train: ctypes.c_void_p,
+    result_eval: ctypes.c_void_p,
+    idata_split: ctypes.c_int64,
+    callback_train: ctypes.c_void_p,
+    callback_eval: ctypes.c_void_p, /
+):
+    ...
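A hedged sketch of an epoch driver built on this binding (the dataset and result handles come from ggml-opt, which this file binds only as opaque `c_void_p`; how they and `idata_split` are produced is outside this diff):

    # dataset, result_train, result_eval are opaque ggml-opt handles
    # (ggml_opt_dataset_t / ggml_opt_result_t), created elsewhere.
    for epoch in range(2):
        llama_opt_epoch(
            lctx,
            dataset,
            result_train,
            result_eval,
            idata_split,  # e.g. first N batches train, the rest evaluate
            None,         # callback_train: NULL -> no progress reporting
            None,         # callback_eval
        )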