@@ -87,29 +87,29 @@ def _load_shared_library(lib_base_name: str):
 # llama.h bindings

 GGML_USE_CUBLAS = hasattr(_lib, "ggml_init_cublas")
-GGML_CUDA_MAX_DEVICES = ctypes.c_int(16)
-LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else ctypes.c_int(1)
+GGML_CUDA_MAX_DEVICES = 16
+LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else 1

 # #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
-LLAMA_FILE_MAGIC_GGJT = ctypes.c_uint(0x67676A74)
+LLAMA_FILE_MAGIC_GGJT = 0x67676A74
 # #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
-LLAMA_FILE_MAGIC_GGLA = ctypes.c_uint(0x67676C61)
+LLAMA_FILE_MAGIC_GGLA = 0x67676C61
 # #define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
-LLAMA_FILE_MAGIC_GGMF = ctypes.c_uint(0x67676D66)
+LLAMA_FILE_MAGIC_GGMF = 0x67676D66
 # #define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
-LLAMA_FILE_MAGIC_GGML = ctypes.c_uint(0x67676D6C)
+LLAMA_FILE_MAGIC_GGML = 0x67676D6C
 # #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
-LLAMA_FILE_MAGIC_GGSN = ctypes.c_uint(0x6767736E)
+LLAMA_FILE_MAGIC_GGSN = 0x6767736E

 # #define LLAMA_FILE_VERSION 3
-LLAMA_FILE_VERSION = c_int(3)
+LLAMA_FILE_VERSION = 3
 LLAMA_FILE_MAGIC = LLAMA_FILE_MAGIC_GGJT
 LLAMA_FILE_MAGIC_UNVERSIONED = LLAMA_FILE_MAGIC_GGML
 LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
-LLAMA_SESSION_VERSION = c_int(1)
+LLAMA_SESSION_VERSION = 1

 # #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
-LLAMA_DEFAULT_SEED = c_int(0xFFFFFFFF)
+LLAMA_DEFAULT_SEED = 0xFFFFFFFF

 # struct llama_model;
 llama_model_p = c_void_p
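Why the plain-int constants are safe: ctypes converts native Python integers automatically whenever a foreign function's argtypes are declared, so wrapping module-level constants in c_int/c_uint bought nothing. A minimal sketch of that conversion, using only the standard library (reaching libc via CDLL(None) is a POSIX-only assumption, not part of this commit):

import ctypes

libc = ctypes.CDLL(None)  # POSIX-only way to reach libc; illustrative assumption
libc.abs.argtypes = [ctypes.c_int]
libc.abs.restype = ctypes.c_int

print(libc.abs(-16))  # 16 -- the plain Python int is converted to c_int at call time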
@@ -235,23 +235,23 @@ class llama_context_params(Structure):
 # LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
 # };
-LLAMA_FTYPE_ALL_F32 = c_int(0)
-LLAMA_FTYPE_MOSTLY_F16 = c_int(1)
-LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2)
-LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3)
-LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(4)
-LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7)
-LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8)
-LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9)
-LLAMA_FTYPE_MOSTLY_Q2_K = c_int(10)
-LLAMA_FTYPE_MOSTLY_Q3_K_S = c_int(11)
-LLAMA_FTYPE_MOSTLY_Q3_K_M = c_int(12)
-LLAMA_FTYPE_MOSTLY_Q3_K_L = c_int(13)
-LLAMA_FTYPE_MOSTLY_Q4_K_S = c_int(14)
-LLAMA_FTYPE_MOSTLY_Q4_K_M = c_int(15)
-LLAMA_FTYPE_MOSTLY_Q5_K_S = c_int(16)
-LLAMA_FTYPE_MOSTLY_Q5_K_M = c_int(17)
-LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18)
+LLAMA_FTYPE_ALL_F32 = 0
+LLAMA_FTYPE_MOSTLY_F16 = 1
+LLAMA_FTYPE_MOSTLY_Q4_0 = 2
+LLAMA_FTYPE_MOSTLY_Q4_1 = 3
+LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4
+LLAMA_FTYPE_MOSTLY_Q8_0 = 7
+LLAMA_FTYPE_MOSTLY_Q5_0 = 8
+LLAMA_FTYPE_MOSTLY_Q5_1 = 9
+LLAMA_FTYPE_MOSTLY_Q2_K = 10
+LLAMA_FTYPE_MOSTLY_Q3_K_S = 11
+LLAMA_FTYPE_MOSTLY_Q3_K_M = 12
+LLAMA_FTYPE_MOSTLY_Q3_K_L = 13
+LLAMA_FTYPE_MOSTLY_Q4_K_S = 14
+LLAMA_FTYPE_MOSTLY_Q4_K_M = 15
+LLAMA_FTYPE_MOSTLY_Q5_K_S = 16
+LLAMA_FTYPE_MOSTLY_Q5_K_M = 17
+LLAMA_FTYPE_MOSTLY_Q6_K = 18


 # // model quantization parameters
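The switch also repairs equality semantics: ctypes scalars do not define __eq__, so two c_int instances holding the same value compare unequal, which silently broke checks against the old LLAMA_FTYPE_* constants. An illustration (not part of the commit):

from ctypes import c_int

print(c_int(2) == c_int(2))  # False -- ctypes scalars compare by identity
print(2 == 2)                # True  -- the plain-int constants compare by value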
@@ -299,13 +299,13 @@ class llama_model_quantize_params(Structure):
 # // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
 # LLAMA_GRETYPE_CHAR_ALT = 6,
 # };
-LLAMA_GRETYPE_END = c_int(0)
-LLAMA_GRETYPE_ALT = c_int(1)
-LLAMA_GRETYPE_RULE_REF = c_int(2)
-LLAMA_GRETYPE_CHAR = c_int(3)
-LLAMA_GRETYPE_CHAR_NOT = c_int(4)
-LLAMA_GRETYPE_CHAR_RNG_UPPER = c_int(5)
-LLAMA_GRETYPE_CHAR_ALT = c_int(6)
+LLAMA_GRETYPE_END = 0
+LLAMA_GRETYPE_ALT = 1
+LLAMA_GRETYPE_RULE_REF = 2
+LLAMA_GRETYPE_CHAR = 3
+LLAMA_GRETYPE_CHAR_NOT = 4
+LLAMA_GRETYPE_CHAR_RNG_UPPER = 5
+LLAMA_GRETYPE_CHAR_ALT = 6


 # typedef struct llama_grammar_element {
@@ -399,7 +399,7 @@ def llama_mlock_supported() -> bool:
 # // If numa is true, use NUMA optimizations
 # // Call once at the start of the program
 # LLAMA_API void llama_backend_init(bool numa);
-def llama_backend_init(numa: c_bool):
+def llama_backend_init(numa: Union[c_bool, bool]):
     return _lib.llama_backend_init(numa)


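The Union[c_bool, bool] annotation (and the similar Union annotations below) only widens what callers and type checkers may pass; runtime behavior is unchanged because the conversion happens inside ctypes. Hypothetical call sites, assuming the module imports as llama_cpp:

import llama_cpp
from ctypes import c_bool

llama_cpp.llama_backend_init(False)          # plain bool now type-checks
llama_cpp.llama_backend_init(c_bool(False))  # the old explicit form still works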
@@ -521,9 +521,9 @@ def llama_model_quantize(
 #     int n_threads);
 def llama_apply_lora_from_file(
     ctx: llama_context_p,
-    path_lora: c_char_p,
-    path_base_model: c_char_p,
-    n_threads: c_int,
+    path_lora: Union[c_char_p, bytes],
+    path_base_model: Union[c_char_p, bytes],
+    n_threads: Union[c_int, int],
 ) -> int:
     return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)

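With Union[c_char_p, bytes], the natural call passes UTF-8 byte strings straight through. A hedged sketch with made-up paths, assuming an already-created context ctx; the 0 return follows the C API's success convention:

# Hypothetical usage -- the file names and ctx are assumptions, not from the commit:
rc = llama_apply_lora_from_file(
    ctx,
    b"./lora-adapter.bin",  # path_lora as bytes, no c_char_p wrapper needed
    b"./base-model.bin",    # path_base_model likewise
    4,                      # n_threads as a plain int
)
assert rc == 0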
@@ -541,7 +541,7 @@ def llama_model_apply_lora_from_file(
     model: llama_model_p,
     path_lora: Union[c_char_p, bytes],
     path_base_model: Union[c_char_p, bytes],
-    n_threads: c_int,
+    n_threads: Union[c_int, int],
 ) -> int:
     return _lib.llama_model_apply_lora_from_file(
         model, path_lora, path_base_model, n_threads
@@ -621,7 +621,7 @@ def llama_load_session_file(
     ctx: llama_context_p,
     path_session: bytes,
     tokens_out,  # type: Array[llama_token]
-    n_token_capacity: c_size_t,
+    n_token_capacity: Union[c_size_t, int],
     n_token_count_out,  # type: _Pointer[c_size_t]
 ) -> int:
     return _lib.llama_load_session_file(
@@ -644,7 +644,7 @@ def llama_save_session_file(
     ctx: llama_context_p,
     path_session: bytes,
     tokens,  # type: Array[llama_token]
-    n_token_count: c_size_t,
+    n_token_count: Union[c_size_t, int],
 ) -> int:
     return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)

@@ -671,9 +671,9 @@ def llama_save_session_file(
 def llama_eval(
     ctx: llama_context_p,
     tokens,  # type: Array[llama_token]
-    n_tokens: c_int,
-    n_past: c_int,
-    n_threads: c_int,
+    n_tokens: Union[c_int, int],
+    n_past: Union[c_int, int],
+    n_threads: Union[c_int, int],
 ) -> int:
     return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)

@@ -692,9 +692,9 @@ def llama_eval(
 def llama_eval_embd(
     ctx: llama_context_p,
     embd,  # type: Array[c_float]
-    n_tokens: c_int,
-    n_past: c_int,
-    n_threads: c_int,
+    n_tokens: Union[c_int, int],
+    n_past: Union[c_int, int],
+    n_threads: Union[c_int, int],
 ) -> int:
     return _lib.llama_eval_embd(ctx, embd, n_tokens, n_past, n_threads)

@@ -718,8 +718,8 @@ def llama_tokenize(
     ctx: llama_context_p,
     text: bytes,
     tokens,  # type: Array[llama_token]
-    n_max_tokens: c_int,
-    add_bos: c_bool,
+    n_max_tokens: Union[c_int, int],
+    add_bos: Union[c_bool, bool],
 ) -> int:
     return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)

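Tokenization shows the ergonomic gain most clearly, since callers usually hold a Python int capacity and a bool flag. A sketch assuming an initialized context ctx (llama_token is this module's c_int alias):

# Hypothetical usage, assuming `ctx` was created earlier:
n_max_tokens = 32
tokens = (llama_token * n_max_tokens)()  # ctypes output array for the token ids
n = llama_tokenize(ctx, b"Hello, world", tokens, n_max_tokens, True)
if n >= 0:
    print(list(tokens)[:n])  # only the first n slots are valid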
@@ -738,8 +738,8 @@ def llama_tokenize_with_model(
     model: llama_model_p,
     text: bytes,
     tokens,  # type: Array[llama_token]
-    n_max_tokens: c_int,
-    add_bos: c_bool,
+    n_max_tokens: Union[c_int, int],
+    add_bos: Union[c_bool, bool],
 ) -> int:
     return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos)

@@ -809,7 +809,7 @@ def llama_get_vocab(
     ctx: llama_context_p,
     strings,  # type: Array[c_char_p] # type: ignore
     scores,  # type: Array[c_float] # type: ignore
-    capacity: c_int,
+    capacity: Union[c_int, int],
 ) -> int:
     return _lib.llama_get_vocab(ctx, strings, scores, capacity)

@@ -832,7 +832,7 @@ def llama_get_vocab_from_model(
     model: llama_model_p,
     strings,  # type: Array[c_char_p] # type: ignore
     scores,  # type: Array[c_float] # type: ignore
-    capacity: c_int,
+    capacity: Union[c_int, int],
 ) -> int:
     return _lib.llama_get_vocab_from_model(model, strings, scores, capacity)

@@ -935,8 +935,8 @@ def llama_token_nl() -> int:
 #     size_t start_rule_index);
 def llama_grammar_init(
     rules,  # type: Array[llama_grammar_element_p] # type: ignore
-    n_rules: c_size_t,
-    start_rule_index: c_size_t,
+    n_rules: Union[c_size_t, int],
+    start_rule_index: Union[c_size_t, int],
 ) -> llama_grammar_p:
     return _lib.llama_grammar_init(rules, n_rules, start_rule_index)

@@ -967,8 +967,8 @@ def llama_sample_repetition_penalty(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
     last_tokens_data,  # type: Array[llama_token]
-    last_tokens_size: c_int,
-    penalty: c_float,
+    last_tokens_size: Union[c_int, int],
+    penalty: Union[c_float, float],
 ):
     return _lib.llama_sample_repetition_penalty(
         ctx, candidates, last_tokens_data, last_tokens_size, penalty
@@ -991,9 +991,9 @@ def llama_sample_frequency_and_presence_penalties(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
     last_tokens_data,  # type: Array[llama_token]
-    last_tokens_size: c_int,
-    alpha_frequency: c_float,
-    alpha_presence: c_float,
+    last_tokens_size: Union[c_int, int],
+    alpha_frequency: Union[c_float, float],
+    alpha_presence: Union[c_float, float],
 ):
     return _lib.llama_sample_frequency_and_presence_penalties(
         ctx,
@@ -1029,7 +1029,7 @@ def llama_sample_classifier_free_guidance(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
     guidance_ctx: llama_context_p,
-    scale: c_float,
+    scale: Union[c_float, float],
 ):
     return _lib.llama_sample_classifier_free_guidance(
         ctx, candidates, guidance_ctx, scale
@@ -1065,8 +1065,8 @@ def llama_sample_softmax(
 def llama_sample_top_k(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
-    k: c_int,
-    min_keep: c_size_t,
+    k: Union[c_int, int],
+    min_keep: Union[c_size_t, int],
 ):
     return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)

@@ -1085,8 +1085,8 @@ def llama_sample_top_k(
 def llama_sample_top_p(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
-    p: c_float,
-    min_keep: c_size_t,
+    p: Union[c_float, float],
+    min_keep: Union[c_size_t, int],
 ):
     return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)

@@ -1105,8 +1105,8 @@ def llama_sample_top_p(
 def llama_sample_tail_free(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
-    z: c_float,
-    min_keep: c_size_t,
+    z: Union[c_float, float],
+    min_keep: Union[c_size_t, int],
 ):
     return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)

@@ -1125,8 +1125,8 @@ def llama_sample_tail_free(
 def llama_sample_typical(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
-    p: c_float,
-    min_keep: c_size_t,
+    p: Union[c_float, float],
+    min_keep: Union[c_size_t, int],
 ):
     return _lib.llama_sample_typical(ctx, candidates, p, min_keep)

@@ -1144,7 +1144,7 @@ def llama_sample_typical(
 def llama_sample_temperature(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
-    temp: c_float,
+    temp: Union[c_float, float],
 ):
     return _lib.llama_sample_temperature(ctx, candidates, temp)

@@ -1167,9 +1167,9 @@ def llama_sample_temperature(
 def llama_sample_token_mirostat(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
-    tau: c_float,
-    eta: c_float,
-    m: c_int,
+    tau: Union[c_float, float],
+    eta: Union[c_float, float],
+    m: Union[c_int, int],
     mu,  # type: _Pointer[c_float]
 ) -> int:
     return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
@@ -1195,8 +1195,8 @@ def llama_sample_token_mirostat(
 def llama_sample_token_mirostat_v2(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
-    tau: c_float,
-    eta: c_float,
+    tau: Union[c_float, float],
+    eta: Union[c_float, float],
     mu,  # type: _Pointer[c_float]
 ) -> int:
     return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
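All the sampling wrappers follow one pattern: scalar inputs such as tau and eta may now be plain floats, while genuine out-parameters such as mu must remain real ctypes objects so the library can write through them. A sketch, assuming ctx and a populated candidates array already exist:

# Hypothetical sampling step; `ctx` and `candidates` are assumptions:
from ctypes import c_float, byref

mu = c_float(10.0)  # out-parameter: stays a c_float so the C side can update it
token = llama_sample_token_mirostat_v2(
    ctx, candidates, 5.0, 0.1, byref(mu)  # tau and eta passed as plain floats
)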
@@ -1289,5 +1289,5 @@ def llama_print_system_info() -> bytes:
 _llama_initialized = False

 if not _llama_initialized:
-    llama_backend_init(c_bool(False))
+    llama_backend_init(False)
     _llama_initialized = True