
Commit 9e2b21f

ikawrakow (Iwan Kawrakow) and co-author authored
DeepSeek: enable option to merge Q and K tensors (#941)
* Merge Q and K for DeepSeek

* Formatting

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent ba1a753 commit 9e2b21f

File tree

5 files changed: +116, −55 lines

src/llama-arch.h
src/llama-build-context.cpp
src/llama-load-tensors.cpp
src/llama-model.cpp
src/llama-model.h


src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -262,6 +262,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_A,
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
+    LLM_TENSOR_ATTN_KQ_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
     LLM_TENSOR_ATTN_K_B,
     LLM_TENSOR_ATTN_V_B,

src/llama-build-context.cpp

Lines changed: 81 additions & 52 deletions
@@ -5935,6 +5935,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
     const uint32_t n_embd_head_qk_rope = hparams.n_rot;
     const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
+    const uint32_t q_lora_rank = hparams.n_lora_q;
 
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
@@ -5961,68 +5962,96 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
 
         // self_attention
         {
-            struct ggml_tensor * q = NULL;
-            if (!is_lite) {
-                // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
-                q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
-                cb(q, "q", il);
-
-                q = llm_build_norm(ctx0, q, hparams, model.layers[il].attn_q_a_norm, NULL, LLM_NORM_RMS, cb, il);
-                cb(q, "q", il);
-
-                // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
-                q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
-                cb(q, "q", il);
-            } else {
-                q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-                cb(q, "q", il);
+            ggml_tensor * q = nullptr;
+            ggml_tensor * kv_rope_compressed = nullptr;
+            ggml_tensor * q_rope;
+            ggml_tensor * q_nope;
+            ggml_tensor * k_rope;
+            ggml_tensor * kv_compressed;
+            if (model.layers[il].wkq_a_mqa) {
+                auto mqa = ggml_mul_mat(ctx0, model.layers[il].wkq_a_mqa, cur);
+                cb(mqa, "mqa", il);
+                size_t qnb1;
+                if (!is_lite) {
+                    q = ggml_view_2d(ctx0, mqa, q_lora_rank, n_tokens, mqa->nb[1], 0);
+                    q = llm_build_norm(ctx0, q, hparams, model.layers[il].attn_q_a_norm, NULL, LLM_NORM_RMS, cb, il);
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
+                    qnb1 = q->nb[1];
+                    cb(q, "q", il);
+                    kv_rope_compressed = ggml_view_2d(ctx0, mqa, kv_lora_rank + n_embd_head_qk_rope, n_tokens, mqa->nb[1],
+                            q_lora_rank*ggml_element_size(mqa));
+                } else {
+                    q = ggml_view_2d(ctx0, mqa, n_embd_k_gqa, n_tokens, mqa->nb[1], 0);
+                    kv_rope_compressed = ggml_view_2d(ctx0, mqa, kv_lora_rank + n_embd_head_qk_rope, n_tokens, mqa->nb[1],
+                            n_embd_k_gqa*ggml_element_size(mqa));
+                    qnb1 = mqa->nb[1];
+                }
+                q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k), qnb1, 0);
+                q_rope = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k), qnb1, ggml_row_size(q->type, n_embd_head_qk_nope));
+                k_rope = ggml_view_3d(ctx0, kv_rope_compressed, n_embd_head_qk_rope, 1, n_tokens,
+                        mqa->nb[1], mqa->nb[1], ggml_row_size(kv_rope_compressed->type, kv_lora_rank));
+                kv_compressed = ggml_view_2d(ctx0, kv_rope_compressed, kv_lora_rank, n_tokens, mqa->nb[1], 0);
             }
+            else {
+                if (!is_lite) {
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
+                    cb(q, "q", il);
 
-            // split into {n_head * n_embd_head_qk_nope, n_tokens}
-            struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
-                    ggml_row_size(q->type, hparams.n_embd_head_k),
-                    ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                    0);
-            cb(q_nope, "q_nope", il);
+                    kv_rope_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                    cb(kv_rope_compressed, "kv_rope_compressed", il);
 
-            // and {n_head * n_embd_head_qk_rope, n_tokens}
-            struct ggml_tensor * q_rope = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
-                    ggml_row_size(q->type, hparams.n_embd_head_k),
-                    ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                    ggml_row_size(q->type, n_embd_head_qk_nope));
-            cb(q_rope, "q_rope", il);
+                    ggml_build_forward_expand(gf, q);
+                    ggml_build_forward_expand(gf, kv_rope_compressed);
 
-            q_rope = ggml_rope_ext(
-                    ctx0, q_rope, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor_scaled, beta_fast, beta_slow
-            );
-            cb(q_rope, "q_rope", il);
+                    q = llm_build_norm(ctx0, q, hparams, model.layers[il].attn_q_a_norm, NULL, LLM_NORM_RMS, cb, il);
+                    cb(q, "q", il);
+
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
+                    cb(q, "q", il);
+                } else {
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                    cb(q, "q", il);
 
-            // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-            struct ggml_tensor * kv_rope_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-            cb(kv_rope_compresseed, "kv_rope_compresseed", il);
+                    kv_rope_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                    cb(kv_rope_compressed, "kv_rope_compressed", il);
 
-            // and {n_embd_head_qk_rope, n_tokens}
-            struct ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compresseed, n_embd_head_qk_rope, 1, n_tokens,
-                    kv_rope_compresseed->nb[1],
-                    kv_rope_compresseed->nb[1],
-                    ggml_row_size(kv_rope_compresseed->type, kv_lora_rank));
+                    ggml_build_forward_expand(gf, q);
+                    ggml_build_forward_expand(gf, kv_rope_compressed);
+                }
+
+                q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head), 0);
+
+                q_rope = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+
+                k_rope = ggml_view_3d(ctx0, kv_rope_compressed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_rope_compressed->nb[1],
+                        kv_rope_compressed->nb[1],
+                        ggml_row_size(kv_rope_compressed->type, kv_lora_rank));
+
+                kv_compressed = ggml_view_2d(ctx0, kv_rope_compressed, kv_lora_rank, n_tokens,
+                        kv_rope_compressed->nb[1], 0);
+            }
+            cb(q_nope, "q_nope", il);
+            cb(q_rope, "q_rope", il);
             cb(k_rope, "k_rope", il);
+            cb(kv_compressed, "kv_compressed", il);
 
-            // shared RoPE key
-            k_rope = ggml_rope_ext(
-                    ctx0, k_rope, inp_pos, nullptr,
+            q_rope = ggml_rope_ext(ctx0, q_rope, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor_scaled, beta_fast, beta_slow
-            );
-            cb(k_rope, "k_rope", il);
+                    ext_factor, attn_factor_scaled, beta_fast, beta_slow);
+            cb(q_rope, "q_rope", il);
 
-            // split into {kv_lora_rank, n_tokens}
-            struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_rope_compresseed, kv_lora_rank, n_tokens,
-                    kv_rope_compresseed->nb[1],
-                    0);
-            cb(kv_compressed, "kv_compressed", il);
+            k_rope = ggml_rope_ext(ctx0, k_rope, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor_scaled, beta_fast, beta_slow);
+            cb(k_rope, "k_rope", il);
 
             kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, model.layers[il].attn_kv_a_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(kv_compressed, "kv_compressed", il);
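
When the merged wkq_a_mqa weight is present, the graph performs a single ggml_mul_mat for the Q and KV-A projections and then recovers the two pieces with ggml_view_2d/ggml_view_3d slices into the fused output. Below is a minimal, self-contained sketch of that slicing pattern only; the dimensions, tensor names, and the main() wrapper are invented for illustration and are not code from this commit.

// Sketch: one fused mul_mat over a merged projection, then views to split the result.
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd = 128, d_q = 96, d_kv = 80, n_tokens = 4;

    // merged projection: the first d_q output rows are "Q", the remaining d_kv are "KV + rope"
    struct ggml_tensor * w_merged = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, d_q + d_kv);
    struct ggml_tensor * x        = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);

    // one GEMM instead of two: {n_embd, d_q + d_kv} x {n_embd, n_tokens} -> {d_q + d_kv, n_tokens}
    struct ggml_tensor * fused = ggml_mul_mat(ctx, w_merged, x);

    // within each output column, the first d_q values belong to Q, the rest to the KV block
    struct ggml_tensor * q  = ggml_view_2d(ctx, fused, d_q,  n_tokens, fused->nb[1], 0);
    struct ggml_tensor * kv = ggml_view_2d(ctx, fused, d_kv, n_tokens, fused->nb[1],
                                           d_q*ggml_element_size(fused));

    // both views hang off the same fused node in the compute graph
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, q);
    ggml_build_forward_expand(gf, kv);

    ggml_free(ctx);
    return 0;
}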

src/llama-load-tensors.cpp

Lines changed: 32 additions & 3 deletions
@@ -1645,14 +1645,43 @@ bool create_tensors_helper::create_deepseek2_tensors(const LLM_TN & tn) {
 
         layer.attn_kv_a_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
 
+        bool merged = false;
+        if (ml.merge_qkv) {
+            auto q_name = is_lite ? tn(LLM_TENSOR_ATTN_Q, "weight", i) : tn(LLM_TENSOR_ATTN_Q_A, "weight", i);
+            auto k_name = tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i);
+            auto wq = ml.require_tensor_meta(q_name.c_str());
+            auto wk = ml.require_tensor_meta(k_name.c_str());
+            GGML_ASSERT(wq && wk);
+            if (wq->type == wk->type) {
+                GGML_ASSERT(wq->ne[0] == wk->ne[0]);
+                layer.wkq_a_mqa = ggml_new_tensor_2d(ctx_split, wq->type, wq->ne[0], wq->ne[1] + wk->ne[1]);
+                snprintf(layer.wkq_a_mqa->name, GGML_MAX_NAME, "blk.%d.attn_qk_a_mqa.weight", i);
+                if (is_lite) {
+                    layer.wq = ml.create_tensor_as_view(ctx_split, layer.wkq_a_mqa, q_name.c_str(), { wq->ne[0], wq->ne[1] }, 0);
+                } else {
+                    layer.wq_a = ml.create_tensor_as_view(ctx_split, layer.wkq_a_mqa, q_name.c_str(), { wq->ne[0], wq->ne[1] }, 0);
+                }
+                layer.wkv_a_mqa = ml.create_tensor_as_view(ctx_split, layer.wkq_a_mqa, k_name.c_str(), { wk->ne[0], wk->ne[1] }, wq->ne[1]*wq->nb[1]);
+                merged = true;
+                use_mmap_buffer = false;
+                printf("============== Merged %s (%ld x %ld) and %s (%ld x %ld)\n", q_name.c_str(),
+                        wq->ne[0], wq->ne[1], k_name.c_str(), wk->ne[0], wk->ne[1]);
+            }
+        }
+
         if (!is_lite) {
-            layer.wq_a = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
+            if (!merged) {
+                layer.wq_a = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
+            }
             layer.wq_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k});
-        } else {
+        } else if (!merged) {
             layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
         }
 
-        layer.wkv_a_mqa = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i),{n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
+        if (!merged) {
+            layer.wkv_a_mqa = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i),{n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
+        }
+
         layer.wkv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i),
                 {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, llama_model_loader::TENSOR_NOT_REQUIRED);
         if (!layer.wkv_b) {
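
The loader-side counterpart: when ml.merge_qkv is set and both weights share a type and first dimension, one merged tensor is allocated and the original attn_q_a (or attn_q for the lite variant) and attn_kv_a_mqa tensors are registered as views into it, so both load into a single contiguous buffer (hence use_mmap_buffer = false). Below is a rough standalone sketch of the view arithmetic under those assumptions, using plain ggml_view_2d in place of the loader's create_tensor_as_view; all names and sizes are invented.

// Sketch: one merged 2D allocation, two row-block views standing in for the original weights.
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd  = 128;  // shared first dimension; must match for both weights
    const int64_t rows_q  = 96;   // stands in for q_lora_rank (or n_embd_k_gqa for the lite model)
    const int64_t rows_kv = 80;   // stands in for kv_lora_rank + n_embd_head_qk_rope

    // one allocation holding both projections back to back along dimension 1
    struct ggml_tensor * w_merged = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, rows_q + rows_kv);

    // the original tensors become views at row offsets 0 and rows_q
    struct ggml_tensor * wq  = ggml_view_2d(ctx, w_merged, n_embd, rows_q,  w_merged->nb[1], 0);
    struct ggml_tensor * wkv = ggml_view_2d(ctx, w_merged, n_embd, rows_kv, w_merged->nb[1],
                                            rows_q*w_merged->nb[1]);

    // loading data into wq/wkv now fills disjoint slices of w_merged's single buffer
    ggml_set_name(wq,  "attn_q_a.weight (view)");
    ggml_set_name(wkv, "attn_kv_a_mqa.weight (view)");

    ggml_free(ctx);
    return 0;
}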

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
@@ -805,6 +805,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
     { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
     { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
     { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
+    { LLM_TENSOR_ATTN_KQ_A_MQA, "blk.%d.attn_kq_a_mqa" },
     { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
     { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
     { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },

src/llama-model.h

Lines changed: 1 addition & 0 deletions
@@ -160,6 +160,7 @@ struct llama_layer {
     struct ggml_tensor * wq_a = nullptr;
    struct ggml_tensor * wq_b = nullptr;
     struct ggml_tensor * wkv_a_mqa = nullptr;
+    struct ggml_tensor * wkq_a_mqa = nullptr;
     struct ggml_tensor * wkv_b = nullptr;
     struct ggml_tensor * wk_b = nullptr;
     struct ggml_tensor * wv_b = nullptr;

0 commit comments

Comments
 (0)