@@ -7794,9 +7794,7 @@ static void llm_build_kv_store(
77947794 cb(k_cache_view, "k_cache_view", il);
77957795
77967796 // note: storing RoPE-ed version of K in the KV cache
7797- ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
7798- tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
7799- ggml_build_forward_expand(graph, tmp);
7797+ ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
78007798
78017799 assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
78027800
@@ -7814,9 +7812,7 @@ static void llm_build_kv_store(
78147812 v_cur = ggml_transpose(ctx, v_cur);
78157813 }
78167814 cb(v_cache_view, "v_cache_view", il);
7817- tmp=ggml_cpy(ctx, v_cur, v_cache_view);
7818- tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
7819- ggml_build_forward_expand(graph, tmp);
7815+ ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
78207816}
78217817
78227818static struct ggml_tensor * llm_build_norm(
@@ -14607,43 +14603,41 @@ static int llama_decode_internal(
1460714603 }
1460814604 lctx.cached_graph.gf = gf;
1460914605
14610- if(ggml_use_cached_graph(lctx.sched)) {
14611-
14612- // Temporarily store KV cache parameters that will need updated in cached graph.
14606+ // Update K and V cache parameters in cached graph.
14607+ if(gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
14608+
1461314609 const struct llama_hparams & hparams = model.hparams;
14614- const int64_t n_layer = hparams.n_layer;
1461514610 const int64_t kv_head = kv_self.head;
14616- std::vector<void *> k_cache_ptrs;
14617- std::vector<void *> v_cache_ptrs;
14618- for (int il = 0; il < n_layer; ++il) {
14619- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
14620- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
14621- ggml_tensor * tmp_tensor = kv_self.k_l[il];
14622- size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
14623- k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
14624- tmp_tensor = kv_self.v_l[il];
14625- if (cparams.flash_attn) {
14626- tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
14627- } else {
14628- tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
14629- }
14630- v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
14631- }
14632-
14633- // Update KV cache parameters in cached graph.
14634- int k_count = 0;
14635- int v_count = 0;
14636- if(gf != nullptr && gf->nodes != nullptr){
14637- for (int i = 0; i < gf->n_nodes; i++) {
14638- ggml_tensor * node = gf->nodes[i];
14639- if (node->op == GGML_OP_CPY) {
14640- if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
14641- node->src[1]->data = k_cache_ptrs[k_count++];
14642- }
14643- if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
14644- node->src[1]->data = v_cache_ptrs[v_count++];
14611+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
14612+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
14613+
14614+ for (int i = 0; i < gf->n_nodes; i++) {
14615+ ggml_tensor * node = gf->nodes[i];
14616+ if (node->op == GGML_OP_CPY) {
14617+
14618+ // K cache
14619+ const char* k_prefix = "k_cache_view-";
14620+ if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
14621+ int il = atoi(node->src[1]->name + strlen(k_prefix)); // Layer index from name
14622+ ggml_tensor * tmp_tensor = kv_self.k_l[il];
14623+ size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
14624+ node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
14625+ }
14626+
14627+ // V cache
14628+ const char* v_prefix = "v_cache_view-";
14629+ if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
14630+ int il = atoi(node->src[1]->name + strlen(v_prefix)); // Layer index from name
14631+ ggml_tensor * tmp_tensor = kv_self.v_l[il];
14632+ size_t tmp_offset;
14633+ if (cparams.flash_attn) {
14634+ tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
14635+ } else {
14636+ tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
1464514637 }
14638+ node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
1464614639 }
14640+
1464714641 }
1464814642 }
1464914643
0 commit comments