Commit e162e47

Merge remote-tracking branch 'fairydreaming/deepseek2-mla-exp' into tmp

2 parents 0d4ff95 + 7654331

File tree: 10 files changed, +203 -30 lines

convert_hf_to_gguf.py

Lines changed: 22 additions & 0 deletions
@@ -4136,6 +4136,28 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             else:
                 return []
 
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+            k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
+            v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
+
+            return [
+                (self.map_tensor_name(name), data_torch),
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]
+
         return [(self.map_tensor_name(name), data_torch)]
 
     def prepare_tensors(self):
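For reference, a standalone sketch (not part of the commit) of the shape bookkeeping this split performs; the hyperparameter values below are illustrative DeepSeek2-style numbers, not taken from any model config:

import torch

n_head_kv, qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 128, 512  # assumed values

# kv_b_proj.weight: one fused projection from the latent space to the per-head
# no-position K and the per-head V
kv_b_proj = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

kv = kv_b_proj.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=1)

# k_b is stored transposed so that q_nope can later be projected straight into
# the kv_lora_rank latent space (see the llama.cpp graph changes below)
k_b = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
v_b = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)

print(k_b.shape)  # torch.Size([65536, 128])
print(v_b.shape)  # torch.Size([16384, 512])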

gguf-py/gguf/constants.py

Lines changed: 6 additions & 0 deletions
@@ -356,6 +356,8 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_B       = auto()
     ATTN_KV_A_MQA  = auto()
     ATTN_KV_B      = auto()
+    ATTN_K_B       = auto()
+    ATTN_V_B       = auto()
     ATTN_Q_A_NORM  = auto()
     ATTN_KV_A_NORM = auto()
     FFN_SUB_NORM   = auto()
@@ -543,6 +545,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_Q_B:        "blk.{bid}.attn_q_b",
     MODEL_TENSOR.ATTN_KV_A_MQA:   "blk.{bid}.attn_kv_a_mqa",
     MODEL_TENSOR.ATTN_KV_B:       "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_K_B:        "blk.{bid}.attn_k_b",
+    MODEL_TENSOR.ATTN_V_B:        "blk.{bid}.attn_v_b",
     MODEL_TENSOR.ATTN_Q_A_NORM:   "blk.{bid}.attn_q_a_norm",
     MODEL_TENSOR.ATTN_KV_A_NORM:  "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM:   "blk.{bid}.attn_sub_norm",
@@ -1333,6 +1337,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_Q_B,
         MODEL_TENSOR.ATTN_KV_A_MQA,
         MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
         MODEL_TENSOR.ATTN_Q_A_NORM,
         MODEL_TENSOR.ATTN_KV_A_NORM,
         MODEL_TENSOR.ATTN_OUT,

gguf-py/gguf/tensor_mapping.py

Lines changed: 8 additions & 0 deletions
@@ -586,6 +586,14 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
         ),
 
+        MODEL_TENSOR.ATTN_K_B: (
+            "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_V_B: (
+            "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+        ),
+
         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
         ),

src/llama-arch.cpp

Lines changed: 6 additions & 0 deletions
@@ -999,6 +999,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_Q_B,      "blk.%d.attn_q_b" },
             { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
             { LLM_TENSOR_ATTN_KV_B,     "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_K_B,      "blk.%d.attn_k_b" },
+            { LLM_TENSOR_ATTN_V_B,      "blk.%d.attn_v_b" },
             { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_NORM,      "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE,      "blk.%d.ffn_gate" },
@@ -1333,6 +1335,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_B,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_Q,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1350,6 +1354,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_B,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-arch.h

Lines changed: 2 additions & 0 deletions
@@ -277,6 +277,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,

src/llama-kv-cache.cpp

Lines changed: 20 additions & 1 deletion
@@ -53,7 +53,7 @@ bool llama_kv_cache_init(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             struct ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(5u*n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -71,6 +71,11 @@ bool llama_kv_cache_init(
     cache.k_l.reserve(n_layer);
     cache.v_l.reserve(n_layer);
 
+    // DeepSeek MLA
+    cache.kr_l.reserve(n_layer);
+    cache.kv_l.reserve(n_layer);
+    cache.kvt_l.reserve(n_layer);
+
     for (int i = 0; i < n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -97,6 +102,20 @@ bool llama_kv_cache_init(
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
         cache.v_l.push_back(v);
+
+        // DeepSeek MLA
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+        LLAMA_LOG_DEBUG("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
+        ggml_tensor * kr  = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
+        ggml_tensor * kv  = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+        ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+        ggml_format_name(kr,  "cache_kr_l%d", i);
+        ggml_format_name(kv,  "cache_kv_l%d", i);
+        ggml_format_name(kvt, "cache_kvt_l%d", i);
+        cache.kr_l.push_back(kr);
+        cache.kv_l.push_back(kv);
+        cache.kvt_l.push_back(kvt);
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
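For a rough sense of what the three extra per-layer tensors cost, here is a back-of-the-envelope comparison with illustrative DeepSeek-V2-like values (assumed, not read from the commit):

kv_lora_rank        = 512   # hparams.n_lora_kv (assumed)
n_embd_head_qk_rope = 64    # hparams.n_rot (assumed)
bytes_per_elem      = 2     # F16 cache cells

# per token and per layer, the MLA path caches c^KV, its transpose, and K^R
mla_bytes = (2 * kv_lora_rank + n_embd_head_qk_rope) * bytes_per_elem
print(mla_bytes)  # 2176 bytes

# for comparison, a full per-head K/V cache with 128 heads and 192/128 head dims
mha_bytes = 128 * (192 + 128) * bytes_per_elem
print(mha_bytes)  # 81920 bytes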

src/llama-kv-cache.h

Lines changed: 8 additions & 0 deletions
@@ -49,11 +49,19 @@ struct llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    ggml_type type_kr = GGML_TYPE_F16;
+    ggml_type type_kv = GGML_TYPE_F16;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
     std::vector<struct ggml_tensor *> v_l;
 
+    // DeepSeek MLA
+    std::vector<struct ggml_tensor *> kr_l; // per layer
+    std::vector<struct ggml_tensor *> kv_l;
+    std::vector<struct ggml_tensor *> kvt_l;
+
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;

src/llama-model.cpp

Lines changed: 2 additions & 0 deletions
@@ -2912,6 +2912,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
                         layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                        layer.wk_b      = create_tensor(tn(LLM_TENSOR_ATTN_K_B,      "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
+                        layer.wv_b      = create_tensor(tn(LLM_TENSOR_ATTN_V_B,      "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0);
                         layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_head * (n_embd_head_v), n_embd}, 0);
 
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
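Note that create_tensor lists GGUF dimensions innermost-first, i.e. reversed with respect to the torch shapes written by the converter above. A quick sanity check, again with assumed illustrative values:

n_head, qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 128, 512  # assumed

# torch shapes produced by convert_hf_to_gguf.py: (rows, cols)
k_b_torch = (n_head * kv_lora_rank, qk_nope_head_dim)
v_b_torch = (n_head * v_head_dim,   kv_lora_rank)

# the loader sees the same tensors with the dimension order reversed
assert k_b_torch[::-1] == (qk_nope_head_dim, n_head * kv_lora_rank)  # ATTN_K_B dims
assert v_b_torch[::-1] == (kv_lora_rank,     n_head * v_head_dim)    # ATTN_V_B dims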

src/llama-model.h

Lines changed: 2 additions & 0 deletions
@@ -161,6 +161,8 @@ struct llama_layer {
     struct ggml_tensor * wq_b       = nullptr;
     struct ggml_tensor * wkv_a_mqa  = nullptr;
     struct ggml_tensor * wkv_b      = nullptr;
+    struct ggml_tensor * wk_b       = nullptr;
+    struct ggml_tensor * wv_b       = nullptr;
     struct ggml_tensor * wq_cross   = nullptr;
     struct ggml_tensor * wk_cross   = nullptr;
     struct ggml_tensor * wv_cross   = nullptr;

src/llama.cpp

Lines changed: 127 additions & 29 deletions
@@ -6404,6 +6404,10 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        // whether to use n_tokens as the matrix dimension during multiplication or n_head
+        // n_tokens is higher during prompt processing, this allows to optimize for this case
+        bool pp_opt = n_tokens > n_head;
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -6472,33 +6476,33 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(kv_compressed, "kv_compressed", il);
 
-                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
-                cb(kv, "kv", il);
+                struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+                cb(kv_cache_view, "kv_cache_view", il);
 
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        0);
-                cb(k_nope, "k_nope", il);
+                // note: storing c^KV in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
 
-                // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
-                cb(v_states, "v_states", il);
+                struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
+                cb(kv_cache_trans_view, "kv_cache_trans_view", il);
 
-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
+                // note: storing transposed c^KV in the transposed KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
 
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                    0);
-                cb(v_states, "v_states", il);
+                struct ggml_tensor * kv_cache =
+                    ggml_view_2d(ctx0, kv_self.kv_l[il],
+                        kv_lora_rank, n_kv,
+                        ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+                        0);
+                cb(kv_cache, "kv_cache", il);
 
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                struct ggml_tensor * kv_cache_trans =
+                    ggml_view_2d(ctx0, kv_self.kvt_l[il],
+                        n_kv, kv_lora_rank,
+                        ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
+                        0);
+                cb(kv_cache_trans, "kv_cache_trans", il);
+
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -6515,15 +6519,91 @@ struct llm_build_context {
                 );
                 cb(k_pe, "k_pe", il);
 
-                struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
-                cb(q_states, "q_states", il);
+                struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+                cb(kr_cache_view, "kr_cache_view", il);
 
-                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
-                cb(k_states, "k_states", il);
+                // note: storing RoPE-ed version of K^R in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
 
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+                struct ggml_tensor * kr_cache =
+                    ggml_view_2d(ctx0, kv_self.kr_l[il],
+                        n_embd_head_qk_rope, n_kv,
+                        ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+                        0);
+                cb(kr_cache, "kr_cache", il);
+
+                struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
+                cb(wk_b, "wk_b", il);
+
+                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                cb(q_nope, "q_nope_perm", il);
+
+                struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
+                cb(q_nope2, "q_nope2", il);
+
+                if (!pp_opt) {
+                    q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+                    cb(q_nope2, "q_nope2_perm", il);
+                }
+
+                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
+                cb(kq_nope, "kq_nope", il);
+
+                if (!pp_opt) {
+                    kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
+                    cb(kq_nope, "kq_nope_perm", il);
+                }
+
+                if (pp_opt) {
+                    q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+                    cb(q_pe, "q_pe_perm", il);
+                }
+
+                struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
+                cb(kq_pe, "kq_pe", il);
+
+                if (!pp_opt) {
+                    kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3);
+                    cb(kq_pe, "kq_pe_perm", il);
+                }
+
+                struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
+                cb(kq, "kq", il);
+
+                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
+
+                if (!pp_opt) {
+                    kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+                    cb(kq, "kq_soft_max_ext_perm", il);
+                }
+
+                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
+                cb(kqv_compressed, "kqv_compressed", il);
+
+                if (!pp_opt) {
+                    kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 3, 1);
+                    cb(kqv_compressed, "kqv_compressed_perm", il);
+                }
+
+                struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
+                cb(wv_b, "wv_b", il);
+
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
+                cb(kqv, "kqv", il);
+
+                if (pp_opt) {
+                    kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
+                    cb(kqv, "kqv_perm", il);
+                }
+
+                cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
+                cb(cur, "kqv_2d", il);
+
+                ggml_build_forward_expand(gf, cur);
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cb(cur, "kqv_out", il);
             }
 
             if (il == n_layer - 1) {
@@ -9768,6 +9848,24 @@ struct llama_context * llama_init_from_model(
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }
 
+    {
+        size_t memory_size_kr = 0;
+        size_t memory_size_kv = 0;
+
+        for (auto & kr : ctx->kv_self.kr_l) {
+            memory_size_kr += ggml_nbytes(kr);
+        }
+
+        for (auto & kv : ctx->kv_self.kv_l) {
+            memory_size_kv += ggml_nbytes(kv);
+        }
+
+        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
+    }
+
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
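To make the graph above easier to follow, here is a rough single-sequence PyTorch sketch of the attention math it implements (not part of the commit; the KQ mask and the pp_opt permutations are omitted, and all dimensions are illustrative): q_nope is first "absorbed" through attn_k_b into the latent space, the scores are computed directly against the cached c^KV and K^R, and the result is decompressed per head with attn_v_b before the output projection.

import torch
import torch.nn.functional as F

n_head, d_nope, d_rope, d_v, r = 8, 128, 64, 128, 512  # illustrative sizes
n_tokens, n_kv = 4, 16

q_nope = torch.randn(n_head, n_tokens, d_nope)  # no-position part of Q
q_pe   = torch.randn(n_head, n_tokens, d_rope)  # RoPE-ed part of Q
c_kv   = torch.randn(n_kv, r)                   # cached compressed c^KV
k_r    = torch.randn(n_kv, d_rope)              # cached shared RoPE-ed key K^R
w_k_b  = torch.randn(n_head, d_nope, r)         # attn_k_b, one matrix per head
w_v_b  = torch.randn(n_head, r, d_v)            # attn_v_b, one matrix per head

# project q_nope into the latent space instead of decompressing the cache
q_lat = torch.einsum('hnd,hdr->hnr', q_nope, w_k_b)

# scores against the compressed cache plus the positional (RoPE) term
scores = torch.einsum('hnr,kr->hnk', q_lat, c_kv) + torch.einsum('hnd,kd->hnk', q_pe, k_r)
probs  = F.softmax(scores / (d_nope + d_rope) ** 0.5, dim=-1)  # causal mask omitted

# weighted sum of latent vectors, then per-head decompression with attn_v_b
out_lat = torch.einsum('hnk,kr->hnr', probs, c_kv)
out     = torch.einsum('hnr,hrv->hnv', out_lat, w_v_b)  # fed into the wo projection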
