@@ -3180,6 +3180,10 @@ static bool llama_kv_cache_init(
31803180 for (int i = 0; i < (int) n_layer; i++) {
31813181 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
31823182 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
3183+ const uint32_t n_head = hparams.n_head(i);
3184+ const uint32_t n_head_kv = hparams.n_head_kv(i);
3185+ const uint32_t n_embd_head_k = hparams.n_embd_head_k;
3186+
31833187
31843188 struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
31853189 ggml_tensor * k;
@@ -3201,7 +3205,8 @@ static bool llama_kv_cache_init(
32013205 const uint32_t kv_lora_rank = hparams.n_lora_kv;
32023206 LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
32033207#if MLA_USE_TRANSPOSED_CACHE
3204- ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
3208+ ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
3209+ //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
32053210#else
32063211 ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
32073212#endif
@@ -3215,7 +3220,10 @@ static bool llama_kv_cache_init(
32153220 n_mla++;
32163221 }
32173222 else {
3218- k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
3223+ //printf("Creating cache tensors:\n");
3224+ //printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k);
3225+ //k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
3226+ k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
32193227 v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
32203228 ggml_format_name(k, "cache_k_l%d", i);
32213229 ggml_format_name(v, "cache_v_l%d", i);
@@ -8285,11 +8293,20 @@ static void llm_build_kv_store(
82858293 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
82868294 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
82878295
8296+ const int64_t n_head = hparams.n_head(il);
8297+ const int64_t n_head_kv = hparams.n_head_kv(il);
8298+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
8299+ const int64_t n_embd_head_v = hparams.n_embd_head_v;
8300+
82888301 GGML_ASSERT(kv.size == n_ctx);
82898302
8290- struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
8291- (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
8292- cb(k_cache_view, "k_cache_view", il);
8303+ //struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
8304+ // (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
8305+ //cb(k_cache_view, "k_cache_view", il);
8306+
8307+ auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k);
8308+ ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv,
8309+ k_row_size, k_row_size*n_head_kv*kv_head);
82938310
82948311 // note: storing RoPE-ed version of K in the KV cache
82958312 ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
@@ -8708,7 +8725,7 @@ static struct ggml_tensor * llm_build_kqv(
87088725 struct ggml_tensor * k =
87098726 ggml_view_3d(ctx, kv.k_l[il],
87108727 n_embd_head_k, n_kv, n_head_kv,
8711- ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
8728+ ggml_row_size(kv.k_l[il]->type, n_embd_head_k)*n_head_kv, // n_embd_k_gqa),
87128729 ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
87138730 0);
87148731 cb(k, "k", il);
@@ -13509,8 +13526,9 @@ struct llm_build_context {
1350913526 ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
1351013527 cb(kvr, "kvr", il);
1351113528
13512- ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope),
13513- ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head);
13529+ auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
13530+ ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens,
13531+ row_size, row_size*kv_head);
1351413532 ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
1351513533 ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il],
1351613534 kv_lora_rank + n_embd_head_qk_rope, n_kv,