Commit 1b6997d

Convert constants to python types and allow python types in low-level api
1 parent 3434803 commit 1b6997d
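
In practice the change has two visible effects for callers of the low-level bindings: module-level constants such as LLAMA_FTYPE_MOSTLY_Q4_0 or LLAMA_FILE_VERSION become plain Python ints instead of ctypes objects (so they compare, hash, and print like ordinary numbers), and the wrapper functions gain Union[...] annotations so that native Python ints, floats, bools, and bytes are accepted alongside the ctypes equivalents. A minimal sketch of the difference, with the import path assumed from the package layout at this revision:

    import ctypes
    from llama_cpp import llama_cpp

    # ctypes scalars compare by identity, not by value, so the old constants were awkward:
    assert ctypes.c_int(2) != ctypes.c_int(2)
    # As a plain int, the obvious comparison now works:
    assert llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_0 == 2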

File tree

1 file changed (+76, -76 lines)


llama_cpp/llama_cpp.py

Lines changed: 76 additions & 76 deletions
@@ -87,29 +87,29 @@ def _load_shared_library(lib_base_name: str):
 # llama.h bindings

 GGML_USE_CUBLAS = hasattr(_lib, "ggml_init_cublas")
-GGML_CUDA_MAX_DEVICES = ctypes.c_int(16)
-LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else ctypes.c_int(1)
+GGML_CUDA_MAX_DEVICES = 16
+LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else 1

 # #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
-LLAMA_FILE_MAGIC_GGJT = ctypes.c_uint(0x67676A74)
+LLAMA_FILE_MAGIC_GGJT = 0x67676A74
 # #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
-LLAMA_FILE_MAGIC_GGLA = ctypes.c_uint(0x67676C61)
+LLAMA_FILE_MAGIC_GGLA = 0x67676C61
 # #define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
-LLAMA_FILE_MAGIC_GGMF = ctypes.c_uint(0x67676D66)
+LLAMA_FILE_MAGIC_GGMF = 0x67676D66
 # #define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
-LLAMA_FILE_MAGIC_GGML = ctypes.c_uint(0x67676D6C)
+LLAMA_FILE_MAGIC_GGML = 0x67676D6C
 # #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
-LLAMA_FILE_MAGIC_GGSN = ctypes.c_uint(0x6767736E)
+LLAMA_FILE_MAGIC_GGSN = 0x6767736E

 # #define LLAMA_FILE_VERSION 3
-LLAMA_FILE_VERSION = c_int(3)
+LLAMA_FILE_VERSION = 3
 LLAMA_FILE_MAGIC = LLAMA_FILE_MAGIC_GGJT
 LLAMA_FILE_MAGIC_UNVERSIONED = LLAMA_FILE_MAGIC_GGML
 LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
-LLAMA_SESSION_VERSION = c_int(1)
+LLAMA_SESSION_VERSION = 1

 # #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
-LLAMA_DEFAULT_SEED = c_int(0xFFFFFFFF)
+LLAMA_DEFAULT_SEED = 0xFFFFFFFF

 # struct llama_model;
 llama_model_p = c_void_p

@@ -235,23 +235,23 @@ class llama_context_params(Structure):
 # LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
 # };
-LLAMA_FTYPE_ALL_F32 = c_int(0)
-LLAMA_FTYPE_MOSTLY_F16 = c_int(1)
-LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2)
-LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3)
-LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(4)
-LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7)
-LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8)
-LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9)
-LLAMA_FTYPE_MOSTLY_Q2_K = c_int(10)
-LLAMA_FTYPE_MOSTLY_Q3_K_S = c_int(11)
-LLAMA_FTYPE_MOSTLY_Q3_K_M = c_int(12)
-LLAMA_FTYPE_MOSTLY_Q3_K_L = c_int(13)
-LLAMA_FTYPE_MOSTLY_Q4_K_S = c_int(14)
-LLAMA_FTYPE_MOSTLY_Q4_K_M = c_int(15)
-LLAMA_FTYPE_MOSTLY_Q5_K_S = c_int(16)
-LLAMA_FTYPE_MOSTLY_Q5_K_M = c_int(17)
-LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18)
+LLAMA_FTYPE_ALL_F32 = 0
+LLAMA_FTYPE_MOSTLY_F16 = 1
+LLAMA_FTYPE_MOSTLY_Q4_0 = 2
+LLAMA_FTYPE_MOSTLY_Q4_1 = 3
+LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4
+LLAMA_FTYPE_MOSTLY_Q8_0 = 7
+LLAMA_FTYPE_MOSTLY_Q5_0 = 8
+LLAMA_FTYPE_MOSTLY_Q5_1 = 9
+LLAMA_FTYPE_MOSTLY_Q2_K = 10
+LLAMA_FTYPE_MOSTLY_Q3_K_S = 11
+LLAMA_FTYPE_MOSTLY_Q3_K_M = 12
+LLAMA_FTYPE_MOSTLY_Q3_K_L = 13
+LLAMA_FTYPE_MOSTLY_Q4_K_S = 14
+LLAMA_FTYPE_MOSTLY_Q4_K_M = 15
+LLAMA_FTYPE_MOSTLY_Q5_K_S = 16
+LLAMA_FTYPE_MOSTLY_Q5_K_M = 17
+LLAMA_FTYPE_MOSTLY_Q6_K = 18


 # // model quantization parameters

@@ -299,13 +299,13 @@ class llama_model_quantize_params(Structure):
 # // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
 # LLAMA_GRETYPE_CHAR_ALT = 6,
 # };
-LLAMA_GRETYPE_END = c_int(0)
-LLAMA_GRETYPE_ALT = c_int(1)
-LLAMA_GRETYPE_RULE_REF = c_int(2)
-LLAMA_GRETYPE_CHAR = c_int(3)
-LLAMA_GRETYPE_CHAR_NOT = c_int(4)
-LLAMA_GRETYPE_CHAR_RNG_UPPER = c_int(5)
-LLAMA_GRETYPE_CHAR_ALT = c_int(6)
+LLAMA_GRETYPE_END = 0
+LLAMA_GRETYPE_ALT = 1
+LLAMA_GRETYPE_RULE_REF = 2
+LLAMA_GRETYPE_CHAR = 3
+LLAMA_GRETYPE_CHAR_NOT = 4
+LLAMA_GRETYPE_CHAR_RNG_UPPER = 5
+LLAMA_GRETYPE_CHAR_ALT = 6


 # typedef struct llama_grammar_element {

@@ -399,7 +399,7 @@ def llama_mlock_supported() -> bool:
 # // If numa is true, use NUMA optimizations
 # // Call once at the start of the program
 # LLAMA_API void llama_backend_init(bool numa);
-def llama_backend_init(numa: c_bool):
+def llama_backend_init(numa: Union[c_bool, bool]):
     return _lib.llama_backend_init(numa)
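
With the widened annotation, the init call can be written with a native bool; for instance (a one-line sketch, assuming the module is imported as llama_cpp as above):

    llama_cpp.llama_backend_init(False)   # plain Python bool; ctypes.c_bool(False) is still accepted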

@@ -521,9 +521,9 @@ def llama_model_quantize(
 # int n_threads);
 def llama_apply_lora_from_file(
     ctx: llama_context_p,
-    path_lora: c_char_p,
-    path_base_model: c_char_p,
-    n_threads: c_int,
+    path_lora: Union[c_char_p, bytes],
+    path_base_model: Union[c_char_p, bytes],
+    n_threads: Union[c_int, int],
 ) -> int:
     return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)

@@ -541,7 +541,7 @@ def llama_model_apply_lora_from_file(
     model: llama_model_p,
     path_lora: Union[c_char_p, bytes],
     path_base_model: Union[c_char_p, bytes],
-    n_threads: c_int,
+    n_threads: Union[c_int, int],
 ) -> int:
     return _lib.llama_model_apply_lora_from_file(
         model, path_lora, path_base_model, n_threads

@@ -621,7 +621,7 @@ def llama_load_session_file(
     ctx: llama_context_p,
     path_session: bytes,
     tokens_out, # type: Array[llama_token]
-    n_token_capacity: c_size_t,
+    n_token_capacity: Union[c_size_t, int],
     n_token_count_out, # type: _Pointer[c_size_t]
 ) -> int:
     return _lib.llama_load_session_file(

@@ -644,7 +644,7 @@ def llama_save_session_file(
     ctx: llama_context_p,
     path_session: bytes,
     tokens, # type: Array[llama_token]
-    n_token_count: c_size_t,
+    n_token_count: Union[c_size_t, int],
 ) -> int:
     return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)

@@ -671,9 +671,9 @@ def llama_save_session_file(
 def llama_eval(
     ctx: llama_context_p,
     tokens, # type: Array[llama_token]
-    n_tokens: c_int,
-    n_past: c_int,
-    n_threads: c_int,
+    n_tokens: Union[c_int, int],
+    n_past: Union[c_int, int],
+    n_threads: Union[c_int, int],
 ) -> int:
     return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)
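
The scalar arguments here may now be given as plain ints. A short sketch, under the assumption that ctx is an already-initialized llama_context_p and tokens is a ctypes array of llama_token holding n_tokens valid entries:

    rc = llama_cpp.llama_eval(ctx, tokens, n_tokens, 0, 4)   # n_past=0, n_threads=4 as plain ints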

@@ -692,9 +692,9 @@ def llama_eval(
 def llama_eval_embd(
     ctx: llama_context_p,
     embd, # type: Array[c_float]
-    n_tokens: c_int,
-    n_past: c_int,
-    n_threads: c_int,
+    n_tokens: Union[c_int, int],
+    n_past: Union[c_int, int],
+    n_threads: Union[c_int, int],
 ) -> int:
     return _lib.llama_eval_embd(ctx, embd, n_tokens, n_past, n_threads)

@@ -718,8 +718,8 @@ def llama_tokenize(
     ctx: llama_context_p,
     text: bytes,
     tokens, # type: Array[llama_token]
-    n_max_tokens: c_int,
-    add_bos: c_bool,
+    n_max_tokens: Union[c_int, int],
+    add_bos: Union[c_bool, bool],
 ) -> int:
     return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
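
A sketch of the calling style this enables, again assuming a live ctx; the buffer size of 64 is an arbitrary choice for the example:

    buf = (llama_cpp.llama_token * 64)()
    n = llama_cpp.llama_tokenize(ctx, b"Hello, world", buf, 64, True)   # plain int and bool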

@@ -738,8 +738,8 @@ def llama_tokenize_with_model(
     model: llama_model_p,
     text: bytes,
     tokens, # type: Array[llama_token]
-    n_max_tokens: c_int,
-    add_bos: c_bool,
+    n_max_tokens: Union[c_int, int],
+    add_bos: Union[c_bool, bool],
 ) -> int:
     return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos)

@@ -809,7 +809,7 @@ def llama_get_vocab(
     ctx: llama_context_p,
     strings, # type: Array[c_char_p] # type: ignore
     scores, # type: Array[c_float] # type: ignore
-    capacity: c_int,
+    capacity: Union[c_int, int],
 ) -> int:
     return _lib.llama_get_vocab(ctx, strings, scores, capacity)

@@ -832,7 +832,7 @@ def llama_get_vocab_from_model(
     model: llama_model_p,
     strings, # type: Array[c_char_p] # type: ignore
     scores, # type: Array[c_float] # type: ignore
-    capacity: c_int,
+    capacity: Union[c_int, int],
 ) -> int:
     return _lib.llama_get_vocab_from_model(model, strings, scores, capacity)

@@ -935,8 +935,8 @@ def llama_token_nl() -> int:
 # size_t start_rule_index);
 def llama_grammar_init(
     rules, # type: Array[llama_grammar_element_p] # type: ignore
-    n_rules: c_size_t,
-    start_rule_index: c_size_t,
+    n_rules: Union[c_size_t, int],
+    start_rule_index: Union[c_size_t, int],
 ) -> llama_grammar_p:
     return _lib.llama_grammar_init(rules, n_rules, start_rule_index)

@@ -967,8 +967,8 @@ def llama_sample_repetition_penalty(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
     last_tokens_data, # type: Array[llama_token]
-    last_tokens_size: c_int,
-    penalty: c_float,
+    last_tokens_size: Union[c_int, int],
+    penalty: Union[c_float, float],
 ):
     return _lib.llama_sample_repetition_penalty(
         ctx, candidates, last_tokens_data, last_tokens_size, penalty

@@ -991,9 +991,9 @@ def llama_sample_frequency_and_presence_penalties(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
     last_tokens_data, # type: Array[llama_token]
-    last_tokens_size: c_int,
-    alpha_frequency: c_float,
-    alpha_presence: c_float,
+    last_tokens_size: Union[c_int, int],
+    alpha_frequency: Union[c_float, float],
+    alpha_presence: Union[c_float, float],
 ):
     return _lib.llama_sample_frequency_and_presence_penalties(
         ctx,

@@ -1029,7 +1029,7 @@ def llama_sample_classifier_free_guidance(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
     guidance_ctx: llama_context_p,
-    scale: c_float,
+    scale: Union[c_float, float],
 ):
     return _lib.llama_sample_classifier_free_guidance(
         ctx, candidates, guidance_ctx, scale

@@ -1065,8 +1065,8 @@ def llama_sample_softmax(
 def llama_sample_top_k(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
-    k: c_int,
-    min_keep: c_size_t,
+    k: Union[c_int, int],
+    min_keep: Union[c_size_t, int],
 ):
     return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
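
For the samplers, k and min_keep likewise no longer need wrapping. A sketch assuming candidates is a llama_token_data_array already filled from the current logits:

    llama_cpp.llama_sample_top_k(ctx, ctypes.byref(candidates), 40, 1)   # k=40, min_keep=1 as plain ints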

@@ -1085,8 +1085,8 @@ def llama_sample_top_k(
 def llama_sample_top_p(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
-    p: c_float,
-    min_keep: c_size_t,
+    p: Union[c_float, float],
+    min_keep: Union[c_size_t, int],
 ):
     return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)

@@ -1105,8 +1105,8 @@ def llama_sample_top_p(
 def llama_sample_tail_free(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
-    z: c_float,
-    min_keep: c_size_t,
+    z: Union[c_float, float],
+    min_keep: Union[c_size_t, int],
 ):
     return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)

@@ -1125,8 +1125,8 @@ def llama_sample_tail_free(
 def llama_sample_typical(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
-    p: c_float,
-    min_keep: c_size_t,
+    p: Union[c_float, float],
+    min_keep: Union[c_size_t, int],
 ):
     return _lib.llama_sample_typical(ctx, candidates, p, min_keep)

@@ -1144,7 +1144,7 @@ def llama_sample_typical(
 def llama_sample_temperature(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
-    temp: c_float,
+    temp: Union[c_float, float],
 ):
     return _lib.llama_sample_temperature(ctx, candidates, temp)
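
The float-typed samplers follow the same pattern; under the same assumptions about ctx and candidates:

    llama_cpp.llama_sample_temperature(ctx, ctypes.byref(candidates), 0.8)   # plain Python float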

@@ -1167,9 +1167,9 @@ def llama_sample_temperature(
 def llama_sample_token_mirostat(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
-    tau: c_float,
-    eta: c_float,
-    m: c_int,
+    tau: Union[c_float, float],
+    eta: Union[c_float, float],
+    m: Union[c_int, int],
     mu, # type: _Pointer[c_float]
 ) -> int:
     return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)

@@ -1195,8 +1195,8 @@ def llama_sample_token_mirostat(
 def llama_sample_token_mirostat_v2(
     ctx: llama_context_p,
     candidates, # type: _Pointer[llama_token_data_array]
-    tau: c_float,
-    eta: c_float,
+    tau: Union[c_float, float],
+    eta: Union[c_float, float],
     mu, # type: _Pointer[c_float]
 ) -> int:
     return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)

@@ -1289,5 +1289,5 @@ def llama_print_system_info() -> bytes:
 _llama_initialized = False

 if not _llama_initialized:
-    llama_backend_init(c_bool(False))
+    llama_backend_init(False)
     _llama_initialized = True
