@@ -60,6 +60,7 @@ extern "C" {
60
60
struct llama_model ;
61
61
struct llama_context ;
62
62
struct llama_sampler ;
63
+ struct llama_kv_cache ;
63
64
64
65
typedef int32_t llama_pos;
65
66
typedef int32_t llama_token;
@@ -467,8 +468,9 @@ extern "C" {
467
468
468
469
DEPRECATED (LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
469
470
470
- LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
471
- LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
471
+ LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
472
+ LLAMA_API struct llama_kv_cache * llama_get_kv_cache ( struct llama_context * ctx);
473
+ LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
472
474
473
475
LLAMA_API const struct llama_vocab * llama_model_get_vocab (const struct llama_model * model);
474
476
LLAMA_API enum llama_rope_type llama_model_rope_type (const struct llama_model * model);
@@ -583,7 +585,7 @@ extern "C" {
583
585
// KV cache
584
586
//
585
587
586
- // TODO: remove llama_kv_cache_view_* API
588
+ // TODO: start using struct llama_kv_cache
587
589
588
590
// Information associated with an individual cell in the KV cache view.
589
591
struct llama_kv_cache_view_cell {
@@ -638,41 +640,47 @@ extern "C" {
638
640
639
641
// Returns the number of tokens in the KV cache (slow, use only for debug)
640
642
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
641
- LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx);
643
+ LLAMA_API int32_t llama_kv_cache_n_tokens (const struct llama_kv_cache * kv);
644
+
645
+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx),
646
+ "use llama_kv_cache_n_tokens instead");
642
647
643
648
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
644
- LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx);
649
+ LLAMA_API int32_t llama_kv_cache_used_cells (const struct llama_kv_cache * kv);
650
+
651
+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx),
652
+ "use llama_kv_cache_used_cells instead");
645
653
646
654
// Clear the KV cache - both cell info is erased and KV data is zeroed
647
655
LLAMA_API void llama_kv_cache_clear (
648
- struct llama_context * ctx );
656
+ struct llama_kv_cache * kv );
649
657
650
658
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
651
659
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
652
660
// seq_id < 0 : match any sequence
653
661
// p0 < 0 : [0, p1]
654
662
// p1 < 0 : [p0, inf)
655
663
LLAMA_API bool llama_kv_cache_seq_rm (
656
- struct llama_context * ctx ,
657
- llama_seq_id seq_id,
658
- llama_pos p0,
659
- llama_pos p1);
664
+ struct llama_kv_cache * kv ,
665
+ llama_seq_id seq_id,
666
+ llama_pos p0,
667
+ llama_pos p1);
660
668
661
669
// Copy all tokens that belong to the specified sequence to another sequence
662
670
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
663
671
// p0 < 0 : [0, p1]
664
672
// p1 < 0 : [p0, inf)
665
673
LLAMA_API void llama_kv_cache_seq_cp (
666
- struct llama_context * ctx ,
667
- llama_seq_id seq_id_src,
668
- llama_seq_id seq_id_dst,
669
- llama_pos p0,
670
- llama_pos p1);
674
+ struct llama_kv_cache * kv ,
675
+ llama_seq_id seq_id_src,
676
+ llama_seq_id seq_id_dst,
677
+ llama_pos p0,
678
+ llama_pos p1);
671
679
672
680
// Removes all tokens that do not belong to the specified sequence
673
681
LLAMA_API void llama_kv_cache_seq_keep (
674
- struct llama_context * ctx ,
675
- llama_seq_id seq_id);
682
+ struct llama_kv_cache * kv ,
683
+ llama_seq_id seq_id);
676
684
677
685
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
678
686
// If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -681,11 +689,11 @@ extern "C" {
681
689
// p0 < 0 : [0, p1]
682
690
// p1 < 0 : [p0, inf)
683
691
LLAMA_API void llama_kv_cache_seq_add (
684
- struct llama_context * ctx ,
685
- llama_seq_id seq_id,
686
- llama_pos p0,
687
- llama_pos p1,
688
- llama_pos delta);
692
+ struct llama_kv_cache * kv ,
693
+ llama_seq_id seq_id,
694
+ llama_pos p0,
695
+ llama_pos p1,
696
+ llama_pos delta);
689
697
690
698
// Integer division of the positions by factor of `d > 1`
691
699
// If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -694,31 +702,28 @@ extern "C" {
694
702
// p0 < 0 : [0, p1]
695
703
// p1 < 0 : [p0, inf)
696
704
LLAMA_API void llama_kv_cache_seq_div (
697
- struct llama_context * ctx ,
698
- llama_seq_id seq_id,
699
- llama_pos p0,
700
- llama_pos p1,
701
- int d);
705
+ struct llama_kv_cache * kv ,
706
+ llama_seq_id seq_id,
707
+ llama_pos p0,
708
+ llama_pos p1,
709
+ int d);
702
710
703
711
// Returns the largest position present in the KV cache for the specified sequence
704
712
LLAMA_API llama_pos llama_kv_cache_seq_pos_max (
705
- struct llama_context * ctx,
706
- llama_seq_id seq_id);
707
-
708
- // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
709
- // how to avoid this?
713
+ struct llama_kv_cache * kv,
714
+ llama_seq_id seq_id);
710
715
711
716
// Defragment the KV cache
712
717
// This will be applied:
713
718
// - lazily on next llama_decode()
714
719
// - explicitly with llama_kv_cache_update()
715
- LLAMA_API void llama_kv_cache_defrag (struct llama_context * ctx);
716
-
717
- // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
718
- LLAMA_API void llama_kv_cache_update (struct llama_context * ctx);
720
+ LLAMA_API void llama_kv_cache_defrag (struct llama_kv_cache * kv);
719
721
720
722
// Check if the context supports KV cache shifting
721
- LLAMA_API bool llama_kv_cache_can_shift (struct llama_context * ctx);
723
+ LLAMA_API bool llama_kv_cache_can_shift (const struct llama_kv_cache * kv);
724
+
725
+ // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
726
+ LLAMA_API void llama_update_kv_cache (struct llama_context * ctx, struct llama_kv_cache * kv);
722
727
723
728
//
724
729
// State / sessions
0 commit comments