
Commit b518439

Fix for KV save and load
1 parent 1343d66 commit b518439

2 files changed: 18 additions and 11 deletions

2 files changed

+18
-11
lines changed

src/llama-kv-cache.cpp (1 addition, 0 deletions)

@@ -726,6 +726,7 @@ bool llama_cross_kv_cache_init(struct llama_cross_kv_cache & cache,
                                bool offload) {
     const struct llama_hparams & hparams = model.hparams;
     const int32_t n_layer = hparams.n_layer;
+    cache.cache_filled = false;
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;

src/llama.cpp (17 additions, 11 deletions)

@@ -739,16 +739,19 @@ static struct ggml_tensor * llm_build_cross_kv(
     const float cross_attn_scale = 1.0f / sqrtf(float(qcur->ne[0] / num_heads));
     // Only add the computation of K and V if
     // the cache doesn't already have the data
-    //if (!kv.cache_filled) {
-    //    // Add computation of K, V to the graph
-    //    ggml_build_forward_expand(graph, kcur);
-    //    ggml_build_forward_expand(graph, vcur);
-    //    // Copy K and V to the cross KV cache
-    //    ggml_build_forward_expand(graph, ggml_cpy(ctx, kcur, kv.k_l[il]));
-    //    ggml_build_forward_expand(graph, ggml_cpy(ctx, vcur, kv.v_l[il]));
-    //}
-    struct ggml_tensor * k = kcur;
-    struct ggml_tensor * v = vcur;
+    if (!kv.cache_filled) {
+        // Add computation of K, V to the graph
+        ggml_build_forward_expand(graph, kcur);
+        ggml_build_forward_expand(graph, vcur);
+        // Copy K and V to the cross KV cache
+        ggml_build_forward_expand(graph, ggml_cpy(ctx, kcur, kv.k_l[il]));
+        ggml_build_forward_expand(graph, ggml_cpy(ctx, vcur, kv.v_l[il]));
+        if (il == 0) {
+            printf("Copying KV values to the cross KV cache\n");
+        }
+    }
+    struct ggml_tensor * k = kv.k_l[il];
+    struct ggml_tensor * v = kv.v_l[il];
     // Compute cross attention score
     struct ggml_tensor * q = ggml_reshape_4d(ctx, qcur, qcur->ne[0] / num_heads,
                                              num_heads, qcur->ne[1], qcur->ne[2]);
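
With this hunk the K/V computation and the copies into kv.k_l/kv.v_l are only added to the graph while the cache is still empty, and the attention then always reads K and V from the cache tensors. Below is a minimal standalone sketch of that compute-once / reuse pattern; CrossKVCache, Tensor and build_cross_kv_layer are invented stand-ins for illustration, not the actual ggml/llama.cpp API.

    // Hedged sketch of the compute-once / reuse pattern above; all names here
    // are hypothetical stand-ins, not the real llama.cpp structures.
    #include <cstdio>
    #include <vector>

    struct Tensor { std::vector<float> data; };

    struct CrossKVCache {
        bool cache_filled = false;      // initialized to false, as in llama-kv-cache.cpp
        std::vector<Tensor> k_l, v_l;   // one cached K and V tensor per layer
    };

    void build_cross_kv_layer(CrossKVCache & kv, int il,
                              const Tensor & kcur, const Tensor & vcur) {
        if (!kv.cache_filled) {
            // first pass: store the freshly computed K/V in the cache
            kv.k_l[il] = kcur;          // stands in for ggml_cpy(ctx, kcur, kv.k_l[il])
            kv.v_l[il] = vcur;
            if (il == 0) {
                printf("Copying KV values to the cross KV cache\n");
            }
        }
        // attention always reads from the cache, whether it was just filled or loaded earlier
        const Tensor & k = kv.k_l[il];
        const Tensor & v = kv.v_l[il];
        (void)k; (void)v;               // cross-attention computation would follow here
    }

    int main() {
        CrossKVCache kv;
        kv.k_l.resize(2);
        kv.v_l.resize(2);
        const Tensor kcur{{1.0f}}, vcur{{2.0f}};
        for (int il = 0; il < 2; ++il) build_cross_kv_layer(kv, il, kcur, vcur); // fills the cache
        kv.cache_filled = true;         // flipped after the first decode (see the hunks below)
        for (int il = 0; il < 2; ++il) build_cross_kv_layer(kv, il, kcur, vcur); // reuses it
    }
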
@@ -8320,7 +8323,6 @@ struct llm_build_context {
 
             inpSA = ggml_add(ctx0, inpSA, cur);
         }
-        lctx.kv_cross.cache_filled = true;
 
         cur = ggml_rms_norm(ctx0, inpSA, hparams.f_norm_rms_eps);
         cur = ggml_mul(ctx0, cur, model.output_norm);
@@ -8936,6 +8938,10 @@ static int llama_decode_impl(
         }
     }
 
+    if (llama_model_has_cross_kv(&lctx.model)) {
+        lctx.kv_cross.cache_filled = true;
+    }
+
     // update the kv ring buffer
     {
         kv_self.head += ubatch.n_tokens;
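
The cache_filled flag is no longer set inside the graph-building code but in llama_decode_impl, presumably so that the cache is only marked as filled once the graph has actually been executed and the copies have happened, and only for models that use a cross KV cache. A small hypothetical sketch of that lifecycle follows; Context, Model and decode_impl are stand-ins, not the real llama.cpp API.

    // Hedged sketch: names are illustrative stand-ins for the real llama.cpp types.
    #include <cstdio>

    struct Model   { bool has_cross_kv = true; };
    struct CrossKV { bool cache_filled = false; };
    struct Context { Model model; CrossKV kv_cross; };

    static bool model_has_cross_kv(const Model & m) { return m.has_cross_kv; }

    static int decode_impl(Context & lctx) {
        // ... build the graph and run the computation for this batch here ...

        // only mark the cross KV cache as filled after the compute has run,
        // so the next decode reuses the cached K/V instead of recomputing them
        if (model_has_cross_kv(lctx.model)) {
            lctx.kv_cross.cache_filled = true;
        }
        return 0;
    }

    int main() {
        Context lctx;
        decode_impl(lctx);  // first decode fills the cache and sets the flag
        printf("cache_filled after first decode: %s\n",
               lctx.kv_cross.cache_filled ? "true" : "false");
    }
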
