@@ -1478,7 +1478,7 @@ ggml_tensor * llm_graph_context::build_attn(
 ggml_tensor * llm_graph_context::build_attn_mla(
         llm_graph_input_attn_kv_unified * inp,
         ggml_cgraph * gf,
-        ggml_tensor * wv_decompress,
+        ggml_tensor * wv_b,
         ggml_tensor * wo,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
@@ -1497,8 +1497,8 @@ ggml_tensor * llm_graph_context::build_attn_mla(
     const auto kv_lora_rank = hparams.n_lora_kv;
 
     // note: deepseek with MLA option converts into MQA with larger n_embd (i.e. GQA with 1 group)
-    const int64_t n_embd_k_compressed = kv_lora_rank + hparams.n_rot;
-    const int64_t n_embd_v_compressed = kv_lora_rank;
+    const int64_t n_embd_k_cmpr = kv_lora_rank + hparams.n_rot;
+    const int64_t n_embd_v_cmpr = kv_lora_rank;
 
     // note: this is the smaller n_embd that we get after decompression
     const int64_t n_embd_head_v = hparams.n_embd_head_v;
@@ -1514,17 +1514,17 @@ ggml_tensor * llm_graph_context::build_attn_mla(
     GGML_ASSERT(kv_self->size == n_ctx);
 
     ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il],
-            n_tokens*n_embd_k_compressed,
-            ggml_row_size(kv_self->k_l[il]->type, n_embd_k_compressed)*kv_head);
+            n_tokens*n_embd_k_cmpr,
+            ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr)*kv_head);
     // cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
     ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
 
-    v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_compressed, n_tokens);
+    v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_cmpr, n_tokens);
 
     ggml_tensor * v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il],
-            n_tokens, n_embd_v_compressed,
+            n_tokens, n_embd_v_cmpr,
             (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
             (kv_head)*ggml_element_size(kv_self->v_l[il]));
 
@@ -1543,34 +1543,34 @@ ggml_tensor * llm_graph_context::build_attn_mla(
 
     const auto n_kv = kv_self->n;
 
-    ggml_tensor * k_compressed = ggml_view_2d(ctx0, kv_self->k_l[il],
-            n_embd_k_compressed, n_kv,
-            ggml_row_size(kv_self->k_l[il]->type, n_embd_k_compressed),
+    ggml_tensor * k_cmpr = ggml_view_2d(ctx0, kv_self->k_l[il],
+            n_embd_k_cmpr, n_kv,
+            ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr),
             0);
-    cb(k_compressed, "k_compressed", il);
+    cb(k_cmpr, "k_cmpr", il);
 
-    struct ggml_tensor * v_compressed_trans = ggml_view_2d(ctx0, kv_self->v_l[il],
-            n_kv, n_embd_v_compressed,
+    struct ggml_tensor * v_cmpr_trans = ggml_view_2d(ctx0, kv_self->v_l[il],
+            n_kv, n_embd_v_cmpr,
             ggml_element_size(kv_self->v_l[il])*n_ctx,
             0);
-    cb(v_compressed_trans, "v_compressed_trans", il);
+    cb(v_cmpr_trans, "v_cmpr_trans", il);
 
-    ggml_tensor * q_compressed = ggml_view_2d(ctx0, q_cur,
-            n_embd_k_compressed, n_tokens*n_head,
-            ggml_row_size(q_cur->type, n_embd_k_compressed),
+    ggml_tensor * q_cmpr = ggml_view_2d(ctx0, q_cur,
+            n_embd_k_cmpr, n_tokens*n_head,
+            ggml_row_size(q_cur->type, n_embd_k_cmpr),
             0);
-    cb(q_compressed, "q_compressed", il);
+    cb(q_cmpr, "q_cmpr", il);
 
-    ggml_tensor * kq = ggml_mul_mat(ctx0, k_compressed, q_compressed);
-    cb(kq, "kq", il);
+    ggml_tensor * kq_cmpr = ggml_mul_mat(ctx0, k_cmpr, q_cmpr);
+    cb(kq_cmpr, "kq_cmpr", il);
 
-    kq = ggml_view_3d(ctx0, kq, n_kv, n_tokens, n_head,
-            ggml_row_size(kq->type, n_kv),
-            ggml_row_size(kq->type, n_kv)*n_tokens,
+    kq_cmpr = ggml_view_3d(ctx0, kq_cmpr, n_kv, n_tokens, n_head,
+            ggml_row_size(kq_cmpr->type, n_kv),
+            ggml_row_size(kq_cmpr->type, n_kv)*n_tokens,
             0);
-    cb(kq, "kq_view", il);
+    cb(kq_cmpr, "kq_view", il);
 
-    ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+    ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq_cmpr, kq_mask, kq_scale, hparams.f_max_alibi_bias);
     cb(kq_soft_max, "kq_soft_max", il);
 
     kq_soft_max = ggml_view_2d(ctx0, kq_soft_max,
@@ -1579,24 +1579,24 @@ ggml_tensor * llm_graph_context::build_attn_mla(
             0);
     cb(kq_soft_max, "kq_soft_max_view", il);
 
-    ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, v_compressed_trans, kq_soft_max);
-    cb(kqv_compressed, "kqv_compressed,", il);
+    ggml_tensor * kqv_cmpr = ggml_mul_mat(ctx0, v_cmpr_trans, kq_soft_max);
+    cb(kqv_cmpr, "kqv_cmpr", il);
 
-    kqv_compressed = ggml_view_3d(ctx0, kqv_compressed,
-            n_embd_v_compressed, n_tokens, n_head,
-            ggml_row_size(kqv_compressed->type, n_embd_v_compressed),
-            ggml_row_size(kqv_compressed->type, n_embd_v_compressed)*n_tokens,
+    kqv_cmpr = ggml_view_3d(ctx0, kqv_cmpr,
+            n_embd_v_cmpr, n_tokens, n_head,
+            ggml_row_size(kqv_cmpr->type, n_embd_v_cmpr),
+            ggml_row_size(kqv_cmpr->type, n_embd_v_cmpr)*n_tokens,
             0);
-    cb(kqv_compressed, "kqv_compressed_view", il);
+    cb(kqv_cmpr, "kqv_cmpr_view", il);
 
-    ggml_tensor * wv_decompress_view = ggml_view_3d(ctx0, wv_decompress,
-            n_embd_v_compressed, n_embd_head_v, n_head,
-            ggml_row_size(wv_decompress->type, n_embd_v_compressed),
-            ggml_row_size(wv_decompress->type, n_embd_v_compressed)*n_embd_head_v,
+    ggml_tensor * wv_b_view = ggml_view_3d(ctx0, wv_b,
+            n_embd_v_cmpr, n_embd_head_v, n_head,
+            ggml_row_size(wv_b->type, n_embd_v_cmpr),
+            ggml_row_size(wv_b->type, n_embd_v_cmpr)*n_embd_head_v,
             0);
-    cb(wv_decompress_view, "wv_decompress_view", il);
+    cb(wv_b_view, "wv_b_view", il);
 
-    ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_decompress_view, kqv_compressed);
+    ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_view, kqv_cmpr);
     cb(kqv, "kqv", il);
 
     kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3);