@@ -234,6 +234,22 @@ class llama_context_params(Structure):
 LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18)
 
 
+# // model quantization parameters
+# typedef struct llama_model_quantize_params {
+#     int nthread;                 // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
+#     enum llama_ftype ftype;      // quantize to this llama_ftype
+#     bool allow_requantize;       // allow quantizing non-f32/f16 tensors
+#     bool quantize_output_tensor; // quantize output.weight
+# } llama_model_quantize_params;
+class llama_model_quantize_params(Structure):
+    _fields_ = [
+        ("nthread", c_int),
+        ("ftype", c_int),
+        ("allow_requantize", c_bool),
+        ("quantize_output_tensor", c_bool),
+    ]
+
+
 # LLAMA_API struct llama_context_params llama_context_default_params();
 def llama_context_default_params() -> llama_context_params:
     return _lib.llama_context_default_params()
@@ -243,6 +259,15 @@ def llama_context_default_params() -> llama_context_params:
 _lib.llama_context_default_params.restype = llama_context_params
 
 
+# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params();
+def llama_model_quantize_default_params() -> llama_model_quantize_params:
+    return _lib.llama_model_quantize_default_params()
+
+
+_lib.llama_model_quantize_default_params.argtypes = []
+_lib.llama_model_quantize_default_params.restype = llama_model_quantize_params
+
+
 # LLAMA_API bool llama_mmap_supported();
 def llama_mmap_supported() -> bool:
     return _lib.llama_mmap_supported()
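
For context (not part of this diff), a minimal sketch of how calling code might use the new default-params helper, assuming the low-level bindings are importable as `llama_cpp`:

import llama_cpp

# Start from the library defaults, then override only the fields we care about.
params = llama_cpp.llama_model_quantize_default_params()
params.nthread = 0  # <=0 lets the C side pick std::thread::hardware_concurrency()
params.ftype = llama_cpp.LLAMA_FTYPE_MOSTLY_Q6_K.value  # .value because the constant is a c_int
params.quantize_output_tensor = True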
@@ -308,21 +333,24 @@ def llama_free(ctx: llama_context_p):
 _lib.llama_free.restype = None
 
 
-# TODO: not great API - very likely to change
-# Returns 0 on success
-# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
+# // Returns 0 on success
 # LLAMA_API int llama_model_quantize(
 #         const char * fname_inp,
 #         const char * fname_out,
-#         enum llama_ftype ftype,
-#         int nthread);
+#         const llama_model_quantize_params * params);
 def llama_model_quantize(
-    fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
+    fname_inp: bytes,
+    fname_out: bytes,
+    params,  # type: POINTER(llama_model_quantize_params) # type: ignore
 ) -> int:
-    return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)
+    return _lib.llama_model_quantize(fname_inp, fname_out, params)
 
 
-_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
+_lib.llama_model_quantize.argtypes = [
+    c_char_p,
+    c_char_p,
+    POINTER(llama_model_quantize_params),
+]
 _lib.llama_model_quantize.restype = c_int
 
 
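Again as a sketch rather than part of this change, the rebound entry point could be driven as below; the model paths are hypothetical, and ctypes.byref() supplies the pointer that the new argtypes declare:

import ctypes
import llama_cpp

params = llama_cpp.llama_model_quantize_default_params()
params.ftype = llama_cpp.LLAMA_FTYPE_MOSTLY_Q6_K.value

ret = llama_cpp.llama_model_quantize(
    b"./models/model-f16.bin",   # hypothetical input path
    b"./models/model-q6_k.bin",  # hypothetical output path
    ctypes.byref(params),        # matches POINTER(llama_model_quantize_params)
)
assert ret == 0  # the C API returns 0 on success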