1717
1818
# Load the library
20- def _load_shared_library (lib_base_name ):
20+ def _load_shared_library (lib_base_name : str ):
2121 # Determine the file extension based on the platform
2222 if sys .platform .startswith ("linux" ):
2323 lib_ext = ".so"
@@ -252,7 +252,9 @@ def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
# Copies the state to the specified destination address.
# Destination needs to have allocated enough memory.
# Returns the number of bytes copied
def llama_copy_state_data(
    ctx: llama_context_p, dest  # type: Array[c_uint8]
) -> c_size_t:
    """Copy the context state into *dest*; returns the byte count copied.

    NOTE(review): *dest* must be pre-allocated large enough (see the
    comment above) -- the wrapper does not check the buffer size.
    """
    return _lib.llama_copy_state_data(ctx, dest)
257259
258260
@@ -262,7 +264,9 @@ def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
262264
# Set the state reading from the specified address
# Returns the number of bytes read
def llama_set_state_data(
    ctx: llama_context_p, src  # type: Array[c_uint8]
) -> c_size_t:
    """Restore context state from the buffer *src*; returns bytes read."""
    return _lib.llama_set_state_data(ctx, src)
267271
268272
@@ -274,9 +278,9 @@ def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
274278def llama_load_session_file (
275279 ctx : llama_context_p ,
276280 path_session : bytes ,
277- tokens_out ,
281+ tokens_out , # type: Array[llama_token]
278282 n_token_capacity : c_size_t ,
279- n_token_count_out ,
283+ n_token_count_out , # type: Array[c_size_t]
280284) -> c_size_t :
281285 return _lib .llama_load_session_file (
282286 ctx , path_session , tokens_out , n_token_capacity , n_token_count_out
@@ -294,7 +298,10 @@ def llama_load_session_file(
294298
295299
def llama_save_session_file(
    ctx: llama_context_p,
    path_session: bytes,
    tokens,  # type: Array[llama_token]
    n_token_count: c_size_t,
) -> c_size_t:
    """Thin wrapper over C ``llama_save_session_file``.

    NOTE(review): presumably persists *n_token_count* entries of *tokens*
    to the file at *path_session* -- confirm against llama.h.
    """
    return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
300307
@@ -433,8 +440,8 @@ def llama_token_nl() -> llama_token:
433440# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
434441def llama_sample_repetition_penalty (
435442 ctx : llama_context_p ,
436- candidates ,
437- last_tokens_data ,
443+ candidates , # type: Array[llama_token_data]
444+ last_tokens_data , # type: Array[llama_token]
438445 last_tokens_size : c_int ,
439446 penalty : c_float ,
440447):
@@ -456,8 +463,8 @@ def llama_sample_repetition_penalty(
456463# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
457464def llama_sample_frequency_and_presence_penalties (
458465 ctx : llama_context_p ,
459- candidates ,
460- last_tokens_data ,
466+ candidates , # type: Array[llama_token_data]
467+ last_tokens_data , # type: Array[llama_token]
461468 last_tokens_size : c_int ,
462469 alpha_frequency : c_float ,
463470 alpha_presence : c_float ,
@@ -484,7 +491,10 @@ def llama_sample_frequency_and_presence_penalties(
484491
485492
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
def llama_sample_softmax(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
):
    """Softmax over *candidates* (sorted by logit, per the C API docs above)."""
    return _lib.llama_sample_softmax(ctx, candidates)
489499
490500
@@ -497,7 +507,10 @@ def llama_sample_softmax(ctx: llama_context_p, candidates):
497507
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_k(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
    k: c_int,
    min_keep: c_size_t = c_size_t(1),
):
    """Keep only the top-*k* candidates (at least *min_keep* survive)."""
    return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
503516
@@ -513,7 +526,10 @@ def llama_sample_top_k(
513526
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_p(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
    p: c_float,
    min_keep: c_size_t = c_size_t(1),
):
    """Nucleus (top-p) filtering of *candidates* (at least *min_keep* survive)."""
    return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
519535
@@ -529,7 +545,10 @@ def llama_sample_top_p(
529545
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
def llama_sample_tail_free(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
    z: c_float,
    min_keep: c_size_t = c_size_t(1),
):
    """Tail-free sampling filter with parameter *z* (at least *min_keep* survive)."""
    return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
535554
@@ -545,7 +564,10 @@ def llama_sample_tail_free(
545564
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
def llama_sample_typical(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
    p: c_float,
    min_keep: c_size_t = c_size_t(1),
):
    """Locally typical sampling filter with mass *p* (at least *min_keep* survive)."""
    return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
551573
@@ -559,7 +581,11 @@ def llama_sample_typical(
559581_lib .llama_sample_typical .restype = None
560582
561583
def llama_sample_temperature(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
    temp: c_float,
):
    """Thin wrapper over C ``llama_sample_temperature`` (temperature *temp*)."""
    return _lib.llama_sample_temperature(ctx, candidates, temp)
564590
565591
@@ -578,7 +604,12 @@ def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
def llama_sample_token_mirostat(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
    tau: c_float,
    eta: c_float,
    m: c_int,
    mu,  # type: Array[c_float]
) -> llama_token:
    """Sample a token via Mirostat; *mu* is updated by the algorithm (see above)."""
    return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
584615
@@ -600,7 +631,11 @@ def llama_sample_token_mirostat(
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
def llama_sample_token_mirostat_v2(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
    tau: c_float,
    eta: c_float,
    mu,  # type: Array[c_float]
) -> llama_token:
    """Sample a token via Mirostat 2.0; *mu* is updated by the algorithm (see above)."""
    return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
606641
@@ -616,7 +651,10 @@ def llama_sample_token_mirostat_v2(
616651
617652
# @details Selects the token with the highest probability.
def llama_sample_token_greedy(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
) -> llama_token:
    """Greedy pick: return the id of the highest-probability candidate."""
    return _lib.llama_sample_token_greedy(ctx, candidates)
621659
622660
@@ -628,7 +666,10 @@ def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
628666
629667
# @details Randomly selects a token from the candidates based on their probabilities.
def llama_sample_token(
    ctx: llama_context_p,
    candidates,  # type: Array[llama_token_data]
) -> llama_token:
    """Draw a token id from *candidates* according to their probabilities."""
    return _lib.llama_sample_token(ctx, candidates)
633674
634675
0 commit comments