Skip to content

Commit 4340e63

Browse files
committed
llama : add struct llama_kv_cache (wip) [no ci]
1 parent ae3c1db commit 4340e63

File tree

8 files changed

+428
-415
lines changed

8 files changed

+428
-415
lines changed

common/common.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,9 @@ struct common_init_result common_init_from_params(common_params & params) {
909909
return iparams;
910910
}
911911

912-
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
912+
llama_kv_cache * kv = llama_get_kv_cache(lctx);
913+
914+
if (params.ctx_shift && !llama_kv_cache_can_shift(kv)) {
913915
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
914916
params.ctx_shift = false;
915917
}
@@ -1014,7 +1016,7 @@ struct common_init_result common_init_from_params(common_params & params) {
10141016
if (llama_model_has_decoder(model)) {
10151017
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
10161018
}
1017-
llama_kv_cache_clear(lctx);
1019+
llama_kv_cache_clear(kv);
10181020
llama_synchronize(lctx);
10191021
llama_perf_context_reset(lctx);
10201022
}

common/speculative.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,10 @@ llama_tokens common_speculative_gen_draft(
171171
llama_tokens result;
172172
result.reserve(params.n_draft);
173173

174+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
175+
174176
if (reuse_n == 0) {
175-
llama_kv_cache_clear(ctx);
177+
llama_kv_cache_clear(kv);
176178

177179
prompt.clear();
178180
} else {
@@ -191,14 +193,14 @@ llama_tokens common_speculative_gen_draft(
191193
}
192194

193195
if (reuse_i > 0) {
194-
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
195-
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
196+
llama_kv_cache_seq_rm (kv, 0, 0, reuse_i);
197+
llama_kv_cache_seq_add(kv, 0, reuse_i, -1, -reuse_i);
196198

197199
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
198200
}
199201

200202
if (reuse_n < (int) prompt.size()) {
201-
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
203+
llama_kv_cache_seq_rm (kv, 0, reuse_n, -1);
202204

203205
prompt.erase(prompt.begin() + reuse_n, prompt.end());
204206
}

examples/embedding/embedding.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
3434

3535
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
3636
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
37-
const struct llama_model * model = llama_get_model(ctx);
37+
const llama_model * model = llama_get_model(ctx);
38+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
3839

3940
// clear previous kv_cache values (irrelevant for embeddings)
40-
llama_kv_cache_clear(ctx);
41+
llama_kv_cache_clear(kv);
4142

4243
// run model
4344
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

include/llama.h

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ extern "C" {
6060
struct llama_model;
6161
struct llama_context;
6262
struct llama_sampler;
63+
struct llama_kv_cache;
6364

6465
typedef int32_t llama_pos;
6566
typedef int32_t llama_token;
@@ -467,8 +468,9 @@ extern "C" {
467468

468469
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
469470

470-
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
471-
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
471+
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
472+
LLAMA_API struct llama_kv_cache * llama_get_kv_cache( struct llama_context * ctx);
473+
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
472474

473475
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
474476
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@@ -583,7 +585,7 @@ extern "C" {
583585
// KV cache
584586
//
585587

586-
// TODO: remove llama_kv_cache_view_* API
588+
// TODO: start using struct llama_kv_cache
587589

588590
// Information associated with an individual cell in the KV cache view.
589591
struct llama_kv_cache_view_cell {
@@ -638,41 +640,47 @@ extern "C" {
638640

639641
// Returns the number of tokens in the KV cache (slow, use only for debug)
640642
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
641-
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
643+
LLAMA_API int32_t llama_kv_cache_n_tokens(const struct llama_kv_cache * kv);
644+
645+
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
646+
"use llama_kv_cache_n_tokens instead");
642647

643648
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
644-
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
649+
LLAMA_API int32_t llama_kv_cache_used_cells(const struct llama_kv_cache * kv);
650+
651+
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
652+
"use llama_kv_cache_used_cells instead");
645653

646654
// Clear the KV cache - both cell info is erased and KV data is zeroed
647655
LLAMA_API void llama_kv_cache_clear(
648-
struct llama_context * ctx);
656+
struct llama_kv_cache * kv);
649657

650658
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
651659
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
652660
// seq_id < 0 : match any sequence
653661
// p0 < 0 : [0, p1]
654662
// p1 < 0 : [p0, inf)
655663
LLAMA_API bool llama_kv_cache_seq_rm(
656-
struct llama_context * ctx,
657-
llama_seq_id seq_id,
658-
llama_pos p0,
659-
llama_pos p1);
664+
struct llama_kv_cache * kv,
665+
llama_seq_id seq_id,
666+
llama_pos p0,
667+
llama_pos p1);
660668

661669
// Copy all tokens that belong to the specified sequence to another sequence
662670
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
663671
// p0 < 0 : [0, p1]
664672
// p1 < 0 : [p0, inf)
665673
LLAMA_API void llama_kv_cache_seq_cp(
666-
struct llama_context * ctx,
667-
llama_seq_id seq_id_src,
668-
llama_seq_id seq_id_dst,
669-
llama_pos p0,
670-
llama_pos p1);
674+
struct llama_kv_cache * kv,
675+
llama_seq_id seq_id_src,
676+
llama_seq_id seq_id_dst,
677+
llama_pos p0,
678+
llama_pos p1);
671679

672680
// Removes all tokens that do not belong to the specified sequence
673681
LLAMA_API void llama_kv_cache_seq_keep(
674-
struct llama_context * ctx,
675-
llama_seq_id seq_id);
682+
struct llama_kv_cache * kv,
683+
llama_seq_id seq_id);
676684

677685
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
678686
// If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -681,11 +689,11 @@ extern "C" {
681689
// p0 < 0 : [0, p1]
682690
// p1 < 0 : [p0, inf)
683691
LLAMA_API void llama_kv_cache_seq_add(
684-
struct llama_context * ctx,
685-
llama_seq_id seq_id,
686-
llama_pos p0,
687-
llama_pos p1,
688-
llama_pos delta);
692+
struct llama_kv_cache * kv,
693+
llama_seq_id seq_id,
694+
llama_pos p0,
695+
llama_pos p1,
696+
llama_pos delta);
689697

690698
// Integer division of the positions by factor of `d > 1`
691699
// If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -694,31 +702,28 @@ extern "C" {
694702
// p0 < 0 : [0, p1]
695703
// p1 < 0 : [p0, inf)
696704
LLAMA_API void llama_kv_cache_seq_div(
697-
struct llama_context * ctx,
698-
llama_seq_id seq_id,
699-
llama_pos p0,
700-
llama_pos p1,
701-
int d);
705+
struct llama_kv_cache * kv,
706+
llama_seq_id seq_id,
707+
llama_pos p0,
708+
llama_pos p1,
709+
int d);
702710

703711
// Returns the largest position present in the KV cache for the specified sequence
704712
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
705-
struct llama_context * ctx,
706-
llama_seq_id seq_id);
707-
708-
// TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
709-
// how to avoid this?
713+
struct llama_kv_cache * kv,
714+
llama_seq_id seq_id);
710715

711716
// Defragment the KV cache
712717
// This will be applied:
713718
// - lazily on next llama_decode()
714719
// - explicitly with llama_kv_cache_update()
715-
LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
716-
717-
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
718-
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
720+
LLAMA_API void llama_kv_cache_defrag(struct llama_kv_cache * kv);
719721

720722
// Check if the context supports KV cache shifting
721-
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
723+
LLAMA_API bool llama_kv_cache_can_shift(const struct llama_kv_cache * kv);
724+
725+
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
726+
LLAMA_API void llama_update_kv_cache(struct llama_context * ctx, struct llama_kv_cache * kv);
722727

723728
//
724729
// State / sessions

src/llama-context.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -602,11 +602,15 @@ uint32_t llama_n_seq_max(const struct llama_context * ctx) {
602602
return ctx->kv_self.size;
603603
}
604604

605-
const struct llama_model * llama_get_model(const struct llama_context * ctx) {
605+
const llama_model * llama_get_model(const llama_context * ctx) {
606606
return &ctx->model;
607607
}
608608

609-
enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
609+
llama_kv_cache * llama_get_kv_cache(llama_context * ctx) {
610+
return &ctx->kv_self;
611+
}
612+
613+
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
610614
return ctx->cparams.pooling_type;
611615
}
612616

@@ -1142,7 +1146,7 @@ struct llama_data_read {
11421146
if (dest_seq_id != -1) {
11431147
// single sequence
11441148

1145-
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
1149+
kv_self.seq_rm(dest_seq_id, -1, -1);
11461150

11471151
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
11481152
batch.n_tokens = cell_count;
@@ -1185,7 +1189,7 @@ struct llama_data_read {
11851189
return false;
11861190
}
11871191

1188-
llama_kv_cache_clear(kv_self);
1192+
kv_self.clear();
11891193

11901194
for (uint32_t i = 0; i < cell_count; ++i) {
11911195
llama_kv_cell & cell = kv_self.cells[i];
@@ -1362,9 +1366,9 @@ struct llama_data_read {
13621366

13631367
if (!res) {
13641368
if (seq_id == -1) {
1365-
llama_kv_cache_clear(ctx);
1369+
ctx->kv_self.clear();
13661370
} else {
1367-
llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
1371+
ctx->kv_self.seq_rm(seq_id, -1, -1);
13681372
}
13691373
throw std::runtime_error("failed to restore kv cache");
13701374
}

0 commit comments

Comments
 (0)