@@ -141,6 +141,11 @@ class llama_context_params(Structure):
141141LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes .c_int (8 ) # except 1d tensors
142142LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes .c_int (9 ) # except 1d tensors
143143
144+ # Misc
145+ c_float_p = POINTER (c_float )
146+ c_uint8_p = POINTER (c_uint8 )
147+ c_size_t_p = POINTER (c_size_t )
148+
144149# Functions
145150
146151
@@ -257,7 +262,7 @@ def llama_copy_state_data(ctx: llama_context_p, dest: Array[c_uint8]) -> c_size_
257262 return _lib .llama_copy_state_data (ctx , dest )
258263
259264
260- _lib .llama_copy_state_data .argtypes = [llama_context_p , POINTER ( c_uint8 ) ]
265+ _lib .llama_copy_state_data .argtypes = [llama_context_p , c_uint8_p ]
261266_lib .llama_copy_state_data .restype = c_size_t
262267
263268
@@ -269,7 +274,7 @@ def llama_set_state_data(
269274 return _lib .llama_set_state_data (ctx , src )
270275
271276
272- _lib .llama_set_state_data .argtypes = [llama_context_p , POINTER ( c_uint8 ) ]
277+ _lib .llama_set_state_data .argtypes = [llama_context_p , c_uint8_p ]
273278_lib .llama_set_state_data .restype = c_size_t
274279
275280
@@ -291,7 +296,7 @@ def llama_load_session_file(
291296 c_char_p ,
292297 llama_token_p ,
293298 c_size_t ,
294- POINTER ( c_size_t ) ,
299+ c_size_t_p ,
295300]
296301_lib .llama_load_session_file .restype = c_size_t
297302
@@ -340,7 +345,7 @@ def llama_eval(
340345def llama_tokenize (
341346 ctx : llama_context_p ,
342347 text : bytes ,
343- tokens , # type : Array[llama_token]
348+ tokens : Array [llama_token ],
344349 n_max_tokens : c_int ,
345350 add_bos : c_bool ,
346351) -> c_int :
@@ -385,7 +390,7 @@ def llama_get_logits(ctx: llama_context_p):
385390
386391
387392_lib .llama_get_logits .argtypes = [llama_context_p ]
388- _lib .llama_get_logits .restype = POINTER ( c_float )
393+ _lib .llama_get_logits .restype = c_float_p
389394
390395
391396# Get the embeddings for the input
@@ -395,7 +400,7 @@ def llama_get_embeddings(ctx: llama_context_p):
395400
396401
397402_lib .llama_get_embeddings .argtypes = [llama_context_p ]
398- _lib .llama_get_embeddings .restype = POINTER ( c_float )
403+ _lib .llama_get_embeddings .restype = c_float_p
399404
400405
401406# Token Id -> String. Uses the vocabulary in the provided context
@@ -614,7 +619,7 @@ def llama_sample_token_mirostat(
614619 c_float ,
615620 c_float ,
616621 c_int ,
617- POINTER ( c_float ) ,
622+ c_float_p ,
618623]
619624_lib .llama_sample_token_mirostat .restype = llama_token
620625
@@ -639,7 +644,7 @@ def llama_sample_token_mirostat_v2(
639644 llama_token_data_array_p ,
640645 c_float ,
641646 c_float ,
642- POINTER ( c_float ) ,
647+ c_float_p ,
643648]
644649_lib .llama_sample_token_mirostat_v2 .restype = llama_token
645650
0 commit comments