Commit 5f11a55

kv-cache : remove llama_kv_cache_i
1 parent f5cedbc commit 5f11a55

5 files changed: +330 -339 lines changed

src/llama-context.cpp

Lines changed: 305 additions & 2 deletions
@@ -2533,7 +2533,7 @@ void llama_context_kv_self::kv_self_update() {
 
         auto * gf = graph_init();
 
-        kv_self.build_shift(ctx_compute.get(), gf, this);
+        build_kv_self_shift(ctx_compute.get(), gf);
 
         ggml_backend_sched_alloc_graph(sched.get(), gf);
 
@@ -2559,7 +2559,7 @@ void llama_context_kv_self::kv_self_update() {
 
         auto * gf = graph_init();
 
-        kv_self.build_defrag(ctx_compute.get(), gf, max_nodes(), !cparams.flash_attn);
+        build_kv_self_defrag(ctx_compute.get(), gf);
 
         ggml_backend_sched_alloc_graph(sched.get(), gf);
 
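Both hunks make the same change: graph building for KV-cache maintenance moves off the cache object (kv_self.build_shift, kv_self.build_defrag) and onto the context itself. A condensed sketch of the resulting call sites (abridged; the final input-setting and compute step is an assumption, it is not visible in the context lines above):

    // sketch of the post-refactor call sites in kv_self_update(), abridged
    auto * gf = graph_init();                        // fresh compute graph

    build_kv_self_shift(ctx_compute.get(), gf);      // or: build_kv_self_defrag(ctx_compute.get(), gf)

    ggml_backend_sched_alloc_graph(sched.get(), gf); // allocate graph buffers
    // ... set graph inputs and evaluate gf through the scheduler (assumed) ...
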
@@ -2817,6 +2817,309 @@ ggml_tensor * llama_context_kv_self::build_attn_soft_max(
     return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias);
 }
 
+void llama_context_kv_self::build_kv_self_shift(
+        ggml_context * ctx0,
+        ggml_cgraph * gf) {
+    const auto & hparams = model.hparams;
+
+    const auto & n_layer = hparams.n_layer;
+
+    const auto & n_embd_head_k = hparams.n_embd_head_k;
+    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    //GGML_ASSERT(kv_self.size == n_ctx);
+
+    ggml_tensor * inp_k_shift = build_inp_k_shift(ctx0);
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const int64_t n_head_kv    = hparams.n_head_kv(il);
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+        struct ggml_tensor * k =
+            ggml_view_3d(ctx0, kv_self.k_l[il],
+                n_embd_head_k, n_head_kv, kv_self.size,
+                ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                0);
+
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp_k_shift, rope_factors, kv_self.k_l[il]->buffer);
+
+        ggml_build_forward_expand(gf, cur);
+    }
+}
+
+void llama_context_kv_self::build_kv_self_defrag(
+        ggml_context * ctx0,
+        ggml_cgraph * gf) {
+    const auto & hparams = model.hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+
+    const uint32_t n_kv   = kv_self.cell_max();
+    const uint32_t n_used = kv_self.used;
+
+    assert(n_used <= n_kv);
+
+    //const int64_t t_start = ggml_time_us();
+
+    // number of cells moved
+    uint32_t n_moves = 0;
+
+    // each move requires 6*n_layer tensors (see build_kv_self_defrag)
+    //   - source view, destination view, copy operation
+    //   - x2 for keys and values
+    //const uint32_t max_moves = max_nodes()/(6*n_layer);
+    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
+    const uint32_t max_moves = (max_nodes() - 2*n_layer)/(6*n_layer);
+
+    // determine which KV cells to move where
+    //
+    //  cell i moves to ids[i]
+    //
+    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
+    //
+    std::vector<uint32_t> ids(n_kv, n_kv);
+
+    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
+        const auto & cell0 = kv_self.cells[i0];
+
+        if (!cell0.is_empty()) {
+            ids[i0] = i0;
+
+            continue;
+        }
+
+        // found a hole - fill it with data from the end of the cache
+
+        uint32_t nh = 1;
+
+        // determine the size of the hole
+        while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
+            nh++;
+        }
+
+        uint32_t nf = 0;
+        uint32_t is = n_kv - 1;
+
+        // starting from the end, find nh non-empty cells
+        for (; is > i0; --is) {
+            const auto & cell1 = kv_self.cells[is];
+
+            if (cell1.is_empty() || ids[is] != n_kv) {
+                continue;
+            }
+
+            // non-empty cell which is not yet moved
+            nf++;
+
+            if (nf == nh) {
+                break;
+            }
+        }
+
+        // this can only happen if `n_used` is not accurate, which would be a bug
+        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
+
+        nf = 0;
+
+        uint32_t i1 = is;
+
+        // are we moving a continuous block of memory?
+        bool cont = false;
+
+        // should we stop searching for the next move?
+        bool stop = false;
+
+        // go back and move the nf cells to the hole
+        for (; i1 < n_kv; ++i1) {
+            auto & cell1 = kv_self.cells[i1];
+
+            if (cell1.is_empty() || ids[i1] != n_kv) {
+                if (n_moves == max_moves) {
+                    stop = true;
+                    break;
+                }
+
+                cont = false;
+                continue;
+            }
+
+            // this cell goes to (i0 + nf)
+            ids[i1] = i0 + nf;
+
+            // move the cell meta data
+            kv_self.cells[i0 + nf] = cell1;
+
+            // clear the old cell and move the head there
+            cell1 = llama_kv_cell();
+            kv_self.head = n_used;
+
+            if (!cont) {
+                n_moves++;
+                cont = true;
+            }
+
+            nf++;
+
+            if (nf == nh) {
+                break;
+            }
+        }
+
+        if (stop || n_moves == max_moves) {
+            break;
+        }
+
+        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
+
+        i0 += nh - 1;
+    }
+
+    if (n_moves == 0) {
+        return;
+    }
+
+    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
+
+    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
+
+#if 0
+    // CPU defrag
+    //
+    // TODO: optimizations are possible:
+    //       - multiple threads
+    //       - avoid copying to the host memory when already there
+    //
+    // likely not worth the effort, as we have ggml_graph based defrag
+    //
+
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+    const uint32_t kv_size = size;
+
+    std::vector<uint8_t> buf_k;
+    std::vector<uint8_t> buf_v;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
+
+        const size_t v_size_el = ggml_type_size(v_l[il]->type);
+        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
+
+        buf_k.resize(k_size);
+        buf_v.resize(v_size);
+
+        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
+
+        // batch move [i, i+nm) to [id, id+nm)
+        // note: cells can move only to a lower index
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t id = ids[i];
+
+            if (i == id || id == n_kv) {
+                continue;
+            }
+
+            uint32_t nm = 1;
+
+            while (i + nm < n_kv && ids[i + nm] == id + nm) {
+                nm++;
+            }
+
+            // move keys
+            {
+                const int64_t os =  i*k_size_row;
+                const int64_t od = id*k_size_row;
+
+                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+            }
+
+            // move values (note: they are transposed)
+            {
+                const int64_t os =  i;
+                const int64_t od = id;
+
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
+                }
+            }
+
+            i += nm - 1;
+        }
+
+        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
+    }
+#else
+    for (uint32_t i = 0; i < ids.size(); ++i) {
+        const uint32_t id = ids[i];
+
+        if (i == id || id == ids.size()) {
+            continue;
+        }
+
+        uint32_t nm = 1;
+
+        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
+            nm++;
+        }
+
+        for (uint32_t il = 0; il < n_layer; ++il) { // NOLINT
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
+
+            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
+
+            ggml_tensor * view_v_src;
+            ggml_tensor * view_v_dst;
+
+            if (cparams.flash_attn) {
+                // NOTE: the V cache is not transposed when using flash attention
+                view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+
+                view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+            } else {
+                view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                        ggml_row_size(kv_self.v_l[il]->type, i));
+
+                view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                        ggml_row_size(kv_self.v_l[il]->type, id));
+            }
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+        }
+
+        i += nm - 1;
+    }
+
+    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+#endif
+}
+
 ggml_tensor * llama_context_kv_self::build_inp_embd_enc(
         ggml_context * ctx0,
         int32_t n_tokens,
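The move planning in build_kv_self_defrag is easier to read outside the diff. Below is a minimal standalone sketch (plain source, not part of the commit) of the same idea under simplifying assumptions: occupancy is a plain bool vector, holes are filled one cell at a time rather than in runs, and the max_moves cap is dropped; plan_defrag and its signature are hypothetical:

#include <cstdint>
#include <vector>

// Hypothetical, simplified sketch of the move planning in
// build_kv_self_defrag: cell i moves to ids[i]; ids[i] == i or
// ids[i] == n_kv means "not moved". Holes near the front are filled
// with used cells taken from the back, so cells only ever move to a
// lower index. The real code moves contiguous blocks, updates cell
// metadata in place, and caps the number of moves; all omitted here.
static std::vector<uint32_t> plan_defrag(const std::vector<bool> & used, uint32_t n_used) {
    const uint32_t n_kv = (uint32_t) used.size();

    std::vector<uint32_t> ids(n_kv, n_kv);

    uint32_t is = n_kv; // scans not-yet-moved used cells from the back

    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
        if (used[i0]) {
            ids[i0] = i0; // already in place
            continue;
        }

        // i0 is a hole: find the last used cell that has not been moved yet
        while (is > i0 && (!used[is - 1] || ids[is - 1] != n_kv)) {
            --is;
        }

        if (is <= i0) {
            break; // nothing left to pull forward
        }

        ids[--is] = i0; // cell `is` fills the hole at i0
    }

    return ids;
}

For example, with occupancy {used, empty, used, used} and n_used = 3, the plan maps cell 3 to slot 1 and leaves cells 0 and 2 in place. The real code additionally batches the copies: after planning, a run where ids[i + nm] == id + nm is moved as one contiguous block, which is why each move costs at most 6*n_layer graph nodes per run (K/V source view, destination view, and copy, per layer) rather than per cell.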

src/llama-context.h

Lines changed: 15 additions & 6 deletions
@@ -378,7 +378,7 @@ class llama_context_kv_self : public llama_context {
 
     virtual void build_attn_kv_store(
             ggml_context * ctx0,
-            ggml_cgraph * graph,
+            ggml_cgraph * gf,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             int32_t n_tokens,
@@ -387,7 +387,7 @@ class llama_context_kv_self : public llama_context {
 
     virtual ggml_tensor * build_attn_qkv(
             ggml_context * ctx0,
-            ggml_cgraph * graph,
+            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur,
@@ -401,6 +401,15 @@ class llama_context_kv_self : public llama_context {
             ggml_tensor * kq,
             float kq_scale) override;
 
+    virtual void build_kv_self_shift(
+            ggml_context * ctx0,
+            ggml_cgraph * gf) override;
+
+    // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
+    virtual void build_kv_self_defrag(
+            ggml_context * ctx0,
+            ggml_cgraph * gf) override;
+
     // === encoder-decoder ===
 
     // whether we are computing encoder output or decoder output
@@ -443,7 +452,7 @@ class llama_context_kv_self : public llama_context {
 
     virtual ggml_tensor * build_copy_mask_state(
             ggml_context * ctx0,
-            ggml_cgraph * graph,
+            ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
             ggml_tensor * state_mask,
@@ -454,7 +463,7 @@ class llama_context_kv_self : public llama_context {
 
     virtual ggml_tensor * build_mamba_layer(
             ggml_context * ctx0,
-            ggml_cgraph * graph,
+            ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * state_copy,
             ggml_tensor * state_mask,
@@ -464,7 +473,7 @@ class llama_context_kv_self : public llama_context {
 
     virtual ggml_tensor * build_rwkv_token_shift_load(
             ggml_context * ctx0,
-            ggml_cgraph * graph,
+            ggml_cgraph * gf,
             ggml_tensor * state_copy,
             ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
@@ -480,7 +489,7 @@ class llama_context_kv_self : public llama_context {
 
     virtual ggml_tensor * build_rwkv6_time_mix(
             ggml_context * ctx0,
-            ggml_cgraph * graph,
+            ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * x_prev,
             ggml_tensor * state_copy,
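With llama_kv_cache_i removed, these builders are ordinary virtuals on the context class hierarchy rather than methods behind a separate cache interface. A hypothetical sketch of what specializing them now looks like (illustrative only, not part of this commit; constructors and the rest of the class are omitted):

// hypothetical subclass, for illustration only
class llama_context_kv_self_noshift : public llama_context_kv_self {
    void build_kv_self_shift(ggml_context * ctx0, ggml_cgraph * gf) override {
        // a cache layout that stores un-rotated keys could make the
        // RoPE shift a no-op
        (void) ctx0;
        (void) gf;
    }
};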
