@@ -64,7 +64,10 @@ extern "C" {
6464 struct llama_model ;
6565 struct llama_context ;
6666 struct llama_sampler ;
67- struct llama_kv_cache ;
67+
68+ typedef struct llama_memory_i * llama_memory_t ;
69+
70+ struct llama_kv_cache ; // DEPRECATED (use llama_memory instead)
6871
6972 typedef int32_t llama_pos;
7073 typedef int32_t llama_token;
@@ -496,9 +499,11 @@ extern "C" {
496499 DEPRECATED (LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
497500
498501 LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
499- LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
502+ LLAMA_API llama_memory_t llama_get_memory ( const struct llama_context * ctx);
500503 LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
501504
505+ DEPRECATED (LLAMA_API struct llama_kv_cache * llama_get_kv_self (struct llama_context * ctx), "use llama_get_memory instead");
506+
502507 LLAMA_API const struct llama_vocab * llama_model_get_vocab (const struct llama_model * model);
503508 LLAMA_API enum llama_rope_type llama_model_rope_type (const struct llama_model * model);
504509
@@ -512,6 +517,13 @@ extern "C" {
512517 // Get the model's RoPE frequency scaling factor
513518 LLAMA_API float llama_model_rope_freq_scale_train (const struct llama_model * model);
514519
520+ // Returns the number of classifier outputs (only valid for classifier models)
521+ // Undefined behavior for non-classifier models
522+ LLAMA_API uint32_t llama_model_n_cls_out (const struct llama_model * model);
523+
524+ // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
525+ LLAMA_API const char * llama_model_cls_label (const struct llama_model * model, uint32_t i);
526+
515527 LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_vocab * vocab);
516528
517529 LLAMA_API int32_t llama_vocab_n_tokens (const struct llama_vocab * vocab);
@@ -612,7 +624,78 @@ extern "C" {
612624 int32_t il_end);
613625
614626 //
615- // KV cache
627+ // Memory
628+ //
629+
630+ // Clear the memory contents
631+ LLAMA_API void llama_memory_clear (llama_memory_t mem);
632+
633+ // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
634+ // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
635+ // seq_id < 0 : match any sequence
636+ // p0 < 0 : [0, p1]
637+ // p1 < 0 : [p0, inf)
638+ LLAMA_API bool llama_memory_seq_rm (
639+ llama_memory_t mem,
640+ llama_seq_id seq_id,
641+ llama_pos p0,
642+ llama_pos p1);
643+
644+ // Copy all tokens that belong to the specified sequence to another sequence
645+ // p0 < 0 : [0, p1]
646+ // p1 < 0 : [p0, inf)
647+ LLAMA_API void llama_memory_seq_cp (
648+ llama_memory_t mem,
649+ llama_seq_id seq_id_src,
650+ llama_seq_id seq_id_dst,
651+ llama_pos p0,
652+ llama_pos p1);
653+
654+ // Removes all tokens that do not belong to the specified sequence
655+ LLAMA_API void llama_memory_seq_keep (
656+ llama_memory_t mem,
657+ llama_seq_id seq_id);
658+
659+ // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
660+ // p0 < 0 : [0, p1]
661+ // p1 < 0 : [p0, inf)
662+ LLAMA_API void llama_memory_seq_add (
663+ llama_memory_t mem,
664+ llama_seq_id seq_id,
665+ llama_pos p0,
666+ llama_pos p1,
667+ llama_pos delta);
668+
669+ // Integer division of the positions by factor of `d > 1`
670+ // p0 < 0 : [0, p1]
671+ // p1 < 0 : [p0, inf)
672+ LLAMA_API void llama_memory_seq_div (
673+ llama_memory_t mem,
674+ llama_seq_id seq_id,
675+ llama_pos p0,
676+ llama_pos p1,
677+ int d);
678+
679+ // Returns the smallest position present in the memory for the specified sequence
680+ // This is typically non-zero only for SWA caches
681+ // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
682+ // Return -1 if the sequence is empty
683+ LLAMA_API llama_pos llama_memory_seq_pos_min (
684+ llama_memory_t mem,
685+ llama_seq_id seq_id);
686+
687+ // Returns the largest position present in the memory for the specified sequence
688+ // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
689+ // Return -1 if the sequence is empty
690+ LLAMA_API llama_pos llama_memory_seq_pos_max (
691+ llama_memory_t mem,
692+ llama_seq_id seq_id);
693+
694+ // Check if the memory supports shifting
695+ LLAMA_API bool llama_memory_can_shift (llama_memory_t mem);
696+
697+ //
698+ // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
616699 //
617700
618701 // Returns the number of tokens in the KV cache (slow, use only for debug)
@@ -626,7 +709,7 @@ extern "C" {
626709
627710 // Clear the KV cache - both cell info is erased and KV data is zeroed
628711 LLAMA_API void llama_kv_self_clear (
629- struct llama_context * ctx);
712+ struct llama_context * ctx);
630713
631714 // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
632715 // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
@@ -697,22 +780,22 @@ extern "C" {
697780 // Defragment the KV cache
698781 // This will be applied:
699782 // - lazily on next llama_decode()
700- LLAMA_API DEPRECATED (void llama_kv_self_defrag (struct llama_context * ctx),
783+ DEPRECATED (LLAMA_API void llama_kv_self_defrag (struct llama_context * ctx),
701784 "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
702785
703786 // Check if the context supports KV cache shifting
704787 LLAMA_API bool llama_kv_self_can_shift (const struct llama_context * ctx);
705788
706789 // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
707- LLAMA_API DEPRECATED (void llama_kv_self_update (struct llama_context * ctx),
790+ DEPRECATED (LLAMA_API void llama_kv_self_update (struct llama_context * ctx),
708791 "simply remove this call, updates are applied lazily on the next llama_decode()");
709792
710793 //
711794 // State / sessions
712795 //
713796
714797 // Returns the *actual* size in bytes of the state
715- // (logits, embedding and kv_cache )
798+ // (logits, embedding and memory )
716799 // Only use when saving the state, not when restoring it, otherwise the size may be too small.
717800 LLAMA_API size_t llama_state_get_size (struct llama_context * ctx);
718801 LLAMA_API DEPRECATED (size_t llama_get_state_size (struct llama_context * ctx),
@@ -768,12 +851,12 @@ extern "C" {
768851 size_t n_token_count),
769852 "use llama_state_save_file instead");
770853
771- // Get the exact size needed to copy the KV cache of a single sequence
854+ // Get the exact size needed to copy the state of a single sequence
772855 LLAMA_API size_t llama_state_seq_get_size (
773856 struct llama_context * ctx,
774857 llama_seq_id seq_id);
775858
776- // Copy the KV cache of a single sequence into the specified buffer
859+ // Copy the state of a single sequence into the specified buffer
777860 LLAMA_API size_t llama_state_seq_get_data (
778861 struct llama_context * ctx,
779862 uint8_t * dst,
@@ -839,16 +922,16 @@ extern "C" {
839922 // For encode-decoder contexts, processes the batch using the encoder.
840923 // Can store the encoder output internally for later use by the decoder's cross-attention layers.
841924 // 0 - success
842- // < 0 - error. the KV cache state is restored to the state before this call
925+ // < 0 - error. the memory state is restored to the state before this call
843926 LLAMA_API int32_t llama_encode (
844927 struct llama_context * ctx,
845928 struct llama_batch batch);
846929
847930 // Process a batch of tokens.
848- // Requires KV cache .
931+ // Requires the context to have a memory .
849932 // For encode-decoder contexts, processes the batch using the decoder.
850933 // Positive return values does not mean a fatal error, but rather a warning.
851- // Upon non-zero return values, the KV cache state is restored to the state before this call
934+ // Upon non-zero return values, the memory state is restored to the state before this call
852935 // 0 - success
853936 // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
854937 // 2 - aborted
@@ -919,7 +1002,7 @@ extern "C" {
9191002
9201003 // Get the embeddings for a sequence id
9211004 // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
922- // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1 ] with the rank of the sequence
1005+ // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out ] with the rank(s) of the sequence
9231006 // otherwise: float[n_embd] (1-dimensional)
9241007 LLAMA_API float * llama_get_embeddings_seq (struct llama_context * ctx, llama_seq_id seq_id);
9251008
0 commit comments