@@ -62,6 +62,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
     for (uint32_t il = 0; il < hparams.n_layer; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
@@ -93,7 +95,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * v;
 
         k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
+        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, !is_mla || v_trans ? kv_size : 0);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
@@ -700,7 +704,9 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
-    return true;
+    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+    return !is_mla || v_trans; // TODO: allow context shifting for MLA with flash attention
 }
 
 uint32_t llama_kv_cache_unified::get_size() const {
@@ -733,12 +739,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
     auto * v = layers[ikv].v;
 
     if (!v_trans) {
-        // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+        if (!is_mla) {
+            // note: v->nb[1] <= v->nb[2]
+            return ggml_view_3d(ctx, v,
+                    hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
+                    ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
+                    ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
+                    0);
+        } else {
+            auto * k = layers[ikv].k;
+
+            // note: v->nb[1] == v->nb[2], since MLA is converted to MQA
+            return ggml_view_3d(ctx, k,
+                    hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
+                    ggml_row_size(k->type, hparams.n_embd_head_k),                          // v->nb[1]
+                    ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),                       // v->nb[2]
+                    ggml_row_size(k->type, hparams.n_embd_head_k - hparams.n_embd_head_v)); // offset by n_rot elements
+        }
     }
 
     // note: v->nb[1] > v->nb[2]
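
For reference, a rough standalone sketch (not part of the diff above) of the aliasing trick the MLA path relies on: with flash attention, each token's V row is just the last n_embd_head_v elements of its K row, so no separate V buffer is allocated and get_v() returns a strided view into the K cache. The sizes below are illustrative DeepSeek-style MLA values (576 = 512 kv_lora_rank + 64 n_rot), and the layout assumption (RoPE part stored first in each K row) is only what the offset in the diff implies.

// sketch.cpp -- hypothetical, minimal demo of carving a "V cache" out of a K cache with ggml
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head_k = 576;  // kv_lora_rank + n_rot (illustrative)
    const int64_t n_embd_head_v = 512;  // kv_lora_rank         (illustrative)
    const int64_t kv_size       = 1024; // number of cached tokens (illustrative)

    // single K cache tensor; no V tensor is allocated at all
    struct ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_embd_head_k, kv_size);

    // "V cache" = view into K: one KV head (MQA), each row skips the first
    // n_embd_head_k - n_embd_head_v (= n_rot) elements and keeps the last n_embd_head_v
    struct ggml_tensor * v = ggml_view_3d(ctx, k,
            n_embd_head_v, 1, kv_size,
            ggml_row_size(k->type, n_embd_head_k),                   // nb[1]: head stride
            ggml_row_size(k->type, n_embd_head_k),                   // nb[2]: token stride (== nb[1] for MQA)
            ggml_row_size(k->type, n_embd_head_k - n_embd_head_v));  // byte offset into each row

    GGML_ASSERT(v->ne[0] == n_embd_head_v && v->nb[1] == v->nb[2]);

    ggml_free(ctx);
    return 0;
}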