@@ -539,6 +539,8 @@ enum llm_tensor {
539539 LLM_TENSOR_ATTN_Q_B,
540540 LLM_TENSOR_ATTN_KV_A_MQA,
541541 LLM_TENSOR_ATTN_KV_B,
542+ LLM_TENSOR_ATTN_K_B,
543+ LLM_TENSOR_ATTN_V_B,
542544 LLM_TENSOR_ATTN_Q_A_NORM,
543545 LLM_TENSOR_ATTN_KV_A_NORM,
544546 LLM_TENSOR_ATTN_SUB_NORM,
@@ -1203,6 +1205,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
12031205 { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
12041206 { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
12051207 { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
1208+ { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
1209+ { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
12061210 { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
12071211 { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
12081212 { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
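For reference, the loader below requests these tensors with a "weight" suffix, so the new patterns are expected to resolve to GGUF tensor names like the following (a sketch for layer 0; the actual names depend on the conversion script that produces them):

    // expected names in the GGUF, layer 0:
    //   blk.0.attn_k_b.weight
    //   blk.0.attn_v_b.weight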
@@ -2541,6 +2545,8 @@ struct llama_layer {
25412545 struct ggml_tensor * wq_b;
25422546 struct ggml_tensor * wkv_a_mqa;
25432547 struct ggml_tensor * wkv_b;
2548+ struct ggml_tensor * wk_b;
2549+ struct ggml_tensor * wv_b;
25442550 struct ggml_tensor * wq_cross;
25452551 struct ggml_tensor * wk_cross;
25462552 struct ggml_tensor * wv_cross;
@@ -2669,11 +2675,19 @@ struct llama_kv_cache {
26692675 ggml_type type_k = GGML_TYPE_F16;
26702676 ggml_type type_v = GGML_TYPE_F16;
26712677
2678+ ggml_type type_kr = GGML_TYPE_F16;
2679+ ggml_type type_kv = GGML_TYPE_F16;
2680+
26722681 std::vector<llama_kv_cell> cells;
26732682
26742683 std::vector<struct ggml_tensor *> k_l; // per layer
26752684 std::vector<struct ggml_tensor *> v_l;
26762685
2686+ // DeepSeek MLA
2687+ std::vector<struct ggml_tensor *> kr_l; // per layer
2688+ std::vector<struct ggml_tensor *> kv_l;
2689+ std::vector<struct ggml_tensor *> kvt_l;
2690+
26772691 std::vector<struct ggml_context *> ctxs;
26782692 std::vector<ggml_backend_buffer_t> bufs;
26792693
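The three new per-layer vectors back the MLA cache: kr_l stores the shared RoPE-ed key part K^R (n_embd_head_qk_rope values per cached position), kv_l stores the compressed latent c^KV (kv_lora_rank values per position), and kvt_l keeps a transposed copy of the same latent so the value-side ggml_mul_mat in the graph can read it along the position dimension. A minimal sketch of the per-token, per-layer footprint this adds (mla_cache_elems_per_token is a hypothetical helper, not part of the patch):

    // kr_l  : n_embd_head_qk_rope elements  (RoPE-ed K^R, shared across heads)
    // kv_l  : kv_lora_rank elements         (compressed latent c^KV)
    // kvt_l : kv_lora_rank elements         (same latent, stored transposed)
    static size_t mla_cache_elems_per_token(uint32_t n_embd_head_qk_rope, uint32_t kv_lora_rank) {
        return (size_t) n_embd_head_qk_rope + 2u*kv_lora_rank;
    }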
@@ -3132,7 +3146,7 @@ static bool llama_kv_cache_init(
31323146 for (auto & it : buft_layer_count) {
31333147 int n_layers = it.second;
31343148 struct ggml_init_params params = {
3135- /*.mem_size =*/ 2u *n_layers*ggml_tensor_overhead(),
3149+ /*.mem_size =*/ 5u *n_layers*ggml_tensor_overhead(),
31363150 /*.mem_buffer =*/ NULL,
31373151 /*.no_alloc =*/ true,
31383152 };
@@ -3148,6 +3162,11 @@ static bool llama_kv_cache_init(
31483162 cache.k_l.reserve(n_layer);
31493163 cache.v_l.reserve(n_layer);
31503164
3165+ // DeepSeek MLA
3166+ cache.kr_l.reserve(n_layer);
3167+ cache.kv_l.reserve(n_layer);
3168+ cache.kvt_l.reserve(n_layer);
3169+
31513170 for (int i = 0; i < (int) n_layer; i++) {
31523171 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
31533172 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -3159,6 +3178,21 @@ static bool llama_kv_cache_init(
31593178 ggml_format_name(v, "cache_v_l%d", i);
31603179 cache.k_l.push_back(k);
31613180 cache.v_l.push_back(v);
3181+
3182+
3183+ // DeepSeek MLA
3184+ const uint32_t n_embd_head_qk_rope = hparams.n_rot;
3185+ const uint32_t kv_lora_rank = hparams.n_lora_kv;
3186+ LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
3187+ ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
3188+ ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
3189+ ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
3190+ ggml_format_name(kr, "cache_kr_l%d", i);
3191+ ggml_format_name(kv, "cache_kv_l%d", i);
3192+ ggml_format_name(kvt, "cache_kvt_l%d", i);
3193+ cache.kr_l.push_back(kr);
3194+ cache.kv_l.push_back(kv);
3195+ cache.kvt_l.push_back(kvt);
31623196 }
31633197
31643198 // allocate tensors and initialize the buffers to avoid NaNs in the padding
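To make the sizes concrete: assuming DeepSeek-V2-style hyperparameters (kv_lora_rank = 512, n_embd_head_qk_rope = 64, n_embd_head_qk_nope = 128, n_embd_head_v = 128, n_head = 128 — values not stated in this diff), the existing path caches full per-head K and V, while the MLA path caches only K^R plus two copies of c^KV:

    // naive, per token per layer: n_head*(n_embd_head_qk_nope + n_embd_head_qk_rope) + n_head*n_embd_head_v
    //                           = 128*192 + 128*128 = 40960 elements  (80 KiB at F16)
    // MLA,   per token per layer: n_embd_head_qk_rope + 2*kv_lora_rank
    //                           = 64 + 2*512         =  1088 elements (~2.1 KiB at F16)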
@@ -7644,6 +7678,8 @@ static bool llm_load_tensors(
76447678
76457679 layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
76467680 layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
7681+ layer.wk_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
7682+ layer.wv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0);
76477683 layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd});
76487684
76497685 layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
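wk_b and wv_b are a per-head re-layout of wkv_b for the "weight absorption" trick. wkv_b packs, for each head i, an up-projection W^KB_i (kv_lora_rank -> n_embd_head_qk_nope) and W^VB_i (kv_lora_rank -> n_embd_head_v); instead of materializing per-token K^nope and V, the scores and outputs can be computed directly against the cached latent. A sketch of the algebra (not the exact ggml call sequence):

    k^{nope}_{j,i} = W^{KB}_i c^{KV}_j, \qquad v_{j,i} = W^{VB}_i c^{KV}_j
    q^{nope\top}_{t,i} \cdot k^{nope}_{j,i} = \big(W^{KB\top}_i q^{nope}_{t,i}\big) \cdot c^{KV}_j
    o_{t,i} = W^{VB}_i \textstyle\sum_j a_{t,i,j}\, c^{KV}_j

Here a_{t,i,j} are the softmax weights built from this "nope" score plus the shared RoPE score q^{R}_{t,i} \cdot k^{R}_j. Accordingly, wk_b holds the transposed K blocks ({n_embd_head_qk_nope, kv_lora_rank} per head) and wv_b holds the V blocks ({kv_lora_rank, n_embd_head_v} per head).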
@@ -13396,31 +13432,31 @@ struct llm_build_context {
1339613432 LLM_NORM_RMS, cb, il);
1339713433 cb(kv_compressed, "kv_compressed", il);
1339813434
13399- // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
13400- struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
13401- cb(kv, "kv", il);
13435+ struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
13436+ cb(kv_cache_view, "kv_cache_view", il);
1340213437
13403- // split into {n_head * n_embd_head_qk_nope, n_tokens}
13404- struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
13405- ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
13406- ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
13407- 0);
13408- cb(k_nope, "k_nope", il);
13438+ // note: storing c^KV in the KV cache
13439+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
1340913440
13410- // and {n_head * n_embd_head_v, n_tokens}
13411- struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
13412- ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
13413- ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
13414- ggml_row_size(kv->type, (n_embd_head_qk_nope)));
13415- cb(v_states, "v_states", il);
13441+ struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
13442+ cb(kv_cache_trans_view, "kv_cache_trans_view", il);
1341613443
13417- v_states = ggml_cont(ctx0, v_states);
13418- cb(v_states, "v_states", il);
13444+ // note: storing transposed c^KV in the transposed KV cache
13445+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
1341913446
13420- v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
13421- ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
13422- 0);
13423- cb(v_states, "v_states", il);
13447+ struct ggml_tensor * kv_cache =
13448+ ggml_view_2d(ctx0, kv_self.kv_l[il],
13449+ kv_lora_rank, n_kv,
13450+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
13451+ 0);
13452+ cb(kv_cache, "kv_cache", il);
13453+
13454+ struct ggml_tensor * kv_cache_trans =
13455+ ggml_view_2d(ctx0, kv_self.kvt_l[il],
13456+ n_kv, kv_lora_rank,
13457+ ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
13458+ 0);
13459+ cb(kv_cache_trans, "kv_cache_trans", il);
1342413460
1342513461 q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
1342613462 q_pe = ggml_rope_ext(
@@ -13439,15 +13475,74 @@ struct llm_build_context {
1343913475 );
1344013476 cb(k_pe, "k_pe", il);
1344113477
13442- struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
13443- cb(q_states, "q_states", il);
13478+ struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
13479+ cb(kr_cache_view, "kr_cache_view", il);
1344413480
13445- struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
13446- cb(k_states, "k_states", il);
13481+ // note: storing RoPE-ed version of K^R in the KV cache
13482+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
1344713483
13448- cur = llm_build_kv(ctx0, lctx, kv_self, gf,
13449- model.layers[il].wo, NULL,
13450- k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
13484+ struct ggml_tensor * kr_cache =
13485+ ggml_view_2d(ctx0, kv_self.kr_l[il],
13486+ n_embd_head_qk_rope, n_kv,
13487+ ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
13488+ 0);
13489+ cb(kr_cache, "kr_cache", il);
13490+
13491+ struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
13492+ cb(wk_b, "wk_b", il);
13493+
13494+ struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
13495+ cb(q_nope_perm, "q_nope_perm", il);
13496+
13497+ struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
13498+ cb(q_nope2, "q_nope2", il);
13499+
13500+ struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
13501+ cb(q_nope2_perm, "q_nope2_perm", il);
13502+
13503+ struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
13504+ cb(kq_nope, "kq_nope", il);
13505+
13506+ struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
13507+ cb(q_pe_perm, "q_pe_perm", il);
13508+
13509+ struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
13510+ cb(kq_pe, "kq_pe", il);
13511+
13512+ struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
13513+ cb(kq, "kq", il);
13514+
13515+ kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
13516+ cb(kq, "kq_perm", il);
13517+
13518+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
13519+ cb(kq, "kq_soft_max_ext", il);
13520+
13521+ struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3);
13522+ cb(kq_perm, "kq_soft_max_ext_perm", il);
13523+
13524+ struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm);
13525+ cb(kqv_compressed, "kqv_compressed", il);
13526+
13527+ kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
13528+ cb(kqv_compressed, "kqv_compressed_perm", il);
13529+
13530+ struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
13531+ cb(wv_b, "wv_b", il);
13532+
13533+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
13534+ cb(kqv, "kqv", il);
13535+
13536+ kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
13537+ cb(kqv, "kqv_perm", il);
13538+
13539+ cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
13540+ cb(cur, "kqv_2d", il);
13541+
13542+ ggml_build_forward_expand(gf, cur);
13543+
13544+ cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
13545+ cb(cur, "kqv_out", il);
1345113546 }
1345213547
1345313548 if (il == n_layer - 1) {
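For reference, a shape walk-through of the MLA branch above, derived from the ggml calls (a sketch; ggml lists dimensions as {ne0, ne1, ne2}, and kq_pe multiplies kr_cache with q_pe directly):

    // q_nope                        : {n_embd_head_qk_nope,  n_head,        n_tokens}
    // q_nope_perm                   : {n_embd_head_qk_nope,  n_tokens,      n_head}
    // wk_b (3d view)                : {n_embd_head_qk_nope,  kv_lora_rank,  n_head}
    // q_nope2  = wk_b * q_nope_perm : {kv_lora_rank,         n_tokens,      n_head}
    // q_nope2_perm                  : {kv_lora_rank,         n_head,        n_tokens}
    // kv_cache                      : {kv_lora_rank,         n_kv}
    // kq_nope                       : {n_kv,                 n_head,        n_tokens}
    // kr_cache                      : {n_embd_head_qk_rope,  n_kv}
    // kq_pe    = kr_cache * q_pe    : {n_kv,                 n_head,        n_tokens}
    // kq (permute + cont)           : {n_kv,                 n_tokens,      n_head}   -> soft_max over n_kv
    // kq_perm                       : {n_kv,                 n_head,        n_tokens}
    // kv_cache_trans                : {n_kv,                 kv_lora_rank}
    // kqv_compressed                : {kv_lora_rank,         n_head,        n_tokens}
    // kqv_compressed (permuted)     : {kv_lora_rank,         n_tokens,      n_head}
    // wv_b (3d view)                : {kv_lora_rank,         n_embd_head_v, n_head}
    // kqv = wv_b * kqv_compressed   : {n_embd_head_v,        n_tokens,      n_head}
    // kqv (permute + cont)          : {n_embd_head_v,        n_head,        n_tokens}
    // cur (2d view)                 : {n_embd_head_v*n_head, n_tokens}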
@@ -17853,6 +17948,24 @@ struct llama_context * llama_new_context_with_model(
1785317948 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
1785417949 }
1785517950
17951+ {
17952+ size_t memory_size_kr = 0;
17953+ size_t memory_size_kv = 0;
17954+
17955+ for (auto & kr : ctx->kv_self.kr_l) {
17956+ memory_size_kr += ggml_nbytes(kr);
17957+ }
17958+
17959+ for (auto & kv : ctx->kv_self.kv_l) {
17960+ memory_size_kv += ggml_nbytes(kv);
17961+ }
17962+
17963+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
17964+ (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
17965+ ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
17966+ ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
17967+ }
17968+
1785617969 // graph outputs buffer
1785717970 {
1785817971 // resized during inference when a batch uses more outputs