Skip to content

Commit 909c5c0

Browse files
committed
Revert "llama : remove deprecated llama_kv_self API (ggml-org#15472)"
This reverts commit cd36b5e.
1 parent 1885750 commit 909c5c0

File tree

3 files changed

+297
-6
lines changed

3 files changed

+297
-6
lines changed

include/llama.h

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,111 @@ extern "C" {
666666
// Check if the memory supports shifting
667667
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
668668

669+
//
670+
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
671+
//
672+
673+
// Returns the number of tokens in the KV cache (slow, use only for debug)
674+
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
675+
DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
676+
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
677+
678+
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
679+
DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
680+
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
681+
682+
// Clear the KV cache - both cell info is erased and KV data is zeroed
683+
DEPRECATED(LLAMA_API void llama_kv_self_clear(
684+
struct llama_context * ctx),
685+
"Use llama_memory_clear() instead");
686+
687+
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
688+
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
689+
// seq_id < 0 : match any sequence
690+
// p0 < 0 : [0, p1]
691+
// p1 < 0 : [p0, inf)
692+
DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
693+
struct llama_context * ctx,
694+
llama_seq_id seq_id,
695+
llama_pos p0,
696+
llama_pos p1),
697+
"Use llama_memory_seq_rm() instead");
698+
699+
// Copy all tokens that belong to the specified sequence to another sequence
700+
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
701+
// p0 < 0 : [0, p1]
702+
// p1 < 0 : [p0, inf)
703+
DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
704+
struct llama_context * ctx,
705+
llama_seq_id seq_id_src,
706+
llama_seq_id seq_id_dst,
707+
llama_pos p0,
708+
llama_pos p1),
709+
"Use llama_memory_seq_cp() instead");
710+
711+
// Removes all tokens that do not belong to the specified sequence
712+
DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
713+
struct llama_context * ctx,
714+
llama_seq_id seq_id),
715+
"Use llama_memory_seq_keep() instead");
716+
717+
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
718+
// If the KV cache is RoPEd, the KV data is updated accordingly:
719+
// - lazily on next llama_decode()
720+
// p0 < 0 : [0, p1]
721+
// p1 < 0 : [p0, inf)
722+
DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
723+
struct llama_context * ctx,
724+
llama_seq_id seq_id,
725+
llama_pos p0,
726+
llama_pos p1,
727+
llama_pos delta),
728+
"Use llama_memory_seq_add() instead");
729+
730+
// Integer division of the positions by factor of `d > 1`
731+
// If the KV cache is RoPEd, the KV data is updated accordingly:
732+
// - lazily on next llama_decode()
733+
// p0 < 0 : [0, p1]
734+
// p1 < 0 : [p0, inf)
735+
DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
736+
struct llama_context * ctx,
737+
llama_seq_id seq_id,
738+
llama_pos p0,
739+
llama_pos p1,
740+
int d),
741+
"Use llama_memory_seq_div() instead");
742+
743+
// Returns the smallest position present in the KV cache for the specified sequence
744+
// This is typically non-zero only for SWA caches
745+
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
746+
// Return -1 if the sequence is empty
747+
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
748+
struct llama_context * ctx,
749+
llama_seq_id seq_id),
750+
"Use llama_memory_seq_pos_min() instead");
751+
752+
// Returns the largest position present in the KV cache for the specified sequence
753+
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
754+
// Return -1 if the sequence is empty
755+
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
756+
struct llama_context * ctx,
757+
llama_seq_id seq_id),
758+
"Use llama_memory_seq_pos_max() instead");
759+
760+
// Defragment the KV cache
761+
// This will be applied:
762+
// - lazily on next llama_decode()
763+
DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
764+
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
765+
766+
// Check if the context supports KV cache shifting
767+
DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
768+
"use llama_memory_can_shift() instead");
769+
770+
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
771+
DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
772+
"simply remove this call, updates are applied lazily on the next llama_decode()");
773+
669774
//
670775
// State / sessions
671776
//

src/llama-context.cpp

Lines changed: 185 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ llama_context::llama_context(
9393
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
9494
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
9595
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
96-
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
96+
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
9797
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
9898
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
9999
cparams.n_batch = GGML_KQ_MASK_PAD;
@@ -439,12 +439,26 @@ llama_memory_t llama_context::get_memory() const {
439439
return memory.get();
440440
}
441441

442-
bool llama_context::memory_update(bool optimize) {
442+
// deprecated
443+
void llama_context::kv_self_defrag_sched() {
444+
if (!memory) {
445+
return;
446+
}
447+
448+
memory_force_optimize = true;
449+
}
450+
451+
// deprecated
452+
bool llama_context::kv_self_update(bool optimize) {
443453
if (!memory) {
444454
return false;
445455
}
446456

447457
{
458+
// TODO: remove in the future
459+
optimize |= memory_force_optimize;
460+
memory_force_optimize = false;
461+
448462
const auto mctx = memory->init_update(this, optimize);
449463
switch (mctx->get_status()) {
450464
case LLAMA_MEMORY_STATUS_SUCCESS:
@@ -979,7 +993,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
979993
bool did_optimize = false;
980994

981995
// handle any pending defrags/shifts
982-
memory_update(false);
996+
kv_self_update(false);
983997

984998
llama_memory_context_ptr mctx;
985999

@@ -1004,7 +1018,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
10041018
if (!did_optimize) {
10051019
did_optimize = true;
10061020

1007-
if (memory_update(true)) {
1021+
if (kv_self_update(true)) {
10081022
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
10091023

10101024
continue;
@@ -2324,6 +2338,11 @@ const llama_model * llama_get_model(const llama_context * ctx) {
23242338
return &ctx->get_model();
23252339
}
23262340

2341+
// deprecated
2342+
void llama_kv_self_update(llama_context * ctx) {
2343+
ctx->kv_self_update(false);
2344+
}
2345+
23272346
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
23282347
return ctx->pooling_type();
23292348
}
@@ -2541,6 +2560,168 @@ bool llama_memory_can_shift(llama_memory_t mem) {
25412560
return mem->get_can_shift();
25422561
}
25432562

2563+
//
2564+
// kv cache
2565+
//
2566+
2567+
// deprecated
2568+
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2569+
const auto * kv = llama_get_memory(ctx);
2570+
if (!kv) {
2571+
return 0;
2572+
}
2573+
2574+
int32_t res = 0;
2575+
2576+
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
2577+
const llama_pos p0 = kv->seq_pos_min(s);
2578+
const llama_pos p1 = kv->seq_pos_max(s);
2579+
2580+
if (p0 >= 0) {
2581+
res += (p1 - p0) + 1;
2582+
}
2583+
}
2584+
2585+
return res;
2586+
}
2587+
2588+
// deprecated
2589+
// note: this is the same as above - will be removed anyway, so it's ok
2590+
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2591+
const auto * kv = llama_get_memory(ctx);
2592+
if (!kv) {
2593+
return 0;
2594+
}
2595+
2596+
int32_t res = 0;
2597+
2598+
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
2599+
const llama_pos p0 = kv->seq_pos_min(s);
2600+
const llama_pos p1 = kv->seq_pos_max(s);
2601+
2602+
if (p0 >= 0) {
2603+
res += (p1 - p0) + 1;
2604+
}
2605+
}
2606+
2607+
return res;
2608+
}
2609+
2610+
// deprecated
2611+
void llama_kv_self_clear(llama_context * ctx) {
2612+
auto * kv = llama_get_memory(ctx);
2613+
if (!kv) {
2614+
return;
2615+
}
2616+
2617+
llama_memory_clear(kv, true);
2618+
}
2619+
2620+
// deprecated
2621+
bool llama_kv_self_seq_rm(
2622+
llama_context * ctx,
2623+
llama_seq_id seq_id,
2624+
llama_pos p0,
2625+
llama_pos p1) {
2626+
auto * kv = llama_get_memory(ctx);
2627+
if (!kv) {
2628+
return true;
2629+
}
2630+
2631+
return llama_memory_seq_rm(kv, seq_id, p0, p1);
2632+
}
2633+
2634+
// deprecated
2635+
void llama_kv_self_seq_cp(
2636+
llama_context * ctx,
2637+
llama_seq_id seq_id_src,
2638+
llama_seq_id seq_id_dst,
2639+
llama_pos p0,
2640+
llama_pos p1) {
2641+
auto * kv = llama_get_memory(ctx);
2642+
if (!kv) {
2643+
return;
2644+
}
2645+
2646+
llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
2647+
}
2648+
2649+
// deprecated
2650+
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2651+
auto * kv = llama_get_memory(ctx);
2652+
if (!kv) {
2653+
return;
2654+
}
2655+
2656+
llama_memory_seq_keep(kv, seq_id);
2657+
}
2658+
2659+
// deprecated
2660+
void llama_kv_self_seq_add(
2661+
llama_context * ctx,
2662+
llama_seq_id seq_id,
2663+
llama_pos p0,
2664+
llama_pos p1,
2665+
llama_pos delta) {
2666+
auto * kv = llama_get_memory(ctx);
2667+
if (!kv) {
2668+
return;
2669+
}
2670+
2671+
llama_memory_seq_add(kv, seq_id, p0, p1, delta);
2672+
}
2673+
2674+
// deprecated
2675+
void llama_kv_self_seq_div(
2676+
llama_context * ctx,
2677+
llama_seq_id seq_id,
2678+
llama_pos p0,
2679+
llama_pos p1,
2680+
int d) {
2681+
auto * kv = llama_get_memory(ctx);
2682+
if (!kv) {
2683+
return;
2684+
}
2685+
2686+
llama_memory_seq_div(kv, seq_id, p0, p1, d);
2687+
}
2688+
2689+
// deprecated
2690+
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2691+
auto * kv = llama_get_memory(ctx);
2692+
if (!kv) {
2693+
return -1;
2694+
}
2695+
2696+
return llama_memory_seq_pos_min(kv, seq_id);
2697+
}
2698+
2699+
// deprecated
2700+
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2701+
auto * kv = llama_get_memory(ctx);
2702+
if (!kv) {
2703+
return -1;
2704+
}
2705+
2706+
return llama_memory_seq_pos_max(kv, seq_id);
2707+
}
2708+
2709+
// deprecated
2710+
void llama_kv_self_defrag(llama_context * ctx) {
2711+
// force defrag
2712+
ctx->kv_self_defrag_sched();
2713+
}
2714+
2715+
// deprecated
2716+
bool llama_kv_self_can_shift(const llama_context * ctx) {
2717+
auto * kv = llama_get_memory(ctx);
2718+
if (!kv) {
2719+
return false;
2720+
}
2721+
2722+
return llama_memory_can_shift(kv);
2723+
}
2724+
25442725
// llama state API
25452726

25462727
// deprecated

src/llama-context.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ struct llama_context {
4646

4747
llama_memory_t get_memory() const;
4848

49-
// return true if the memory was updated
50-
bool memory_update(bool optimize);
49+
// return true of the KV cache was updated
50+
// TODO: remove
51+
bool kv_self_update(bool optimize);
52+
void kv_self_defrag_sched();
5153

5254
enum llama_pooling_type pooling_type() const;
5355

@@ -228,6 +230,9 @@ struct llama_context {
228230

229231
std::unique_ptr<llama_memory_i> memory;
230232

233+
// TODO: temporary, until the llama_kv_self_defrag() API is removed
234+
bool memory_force_optimize = false;
235+
231236
// decode output (2-dimensional array: [n_outputs][n_vocab])
232237
size_t logits_size = 0; // capacity (of floats) for logits
233238
float * logits = nullptr;

0 commit comments

Comments
 (0)