28 changes: 20 additions & 8 deletions src/llama-graph.cpp
@@ -1203,9 +1203,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
// note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];

const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
const auto n_kv = k->ne[1];
const auto n_embd = q->ne[0];
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];

const auto n_kv = k->ne[1];
const auto n_head_kv = k->ne[2];

ggml_tensor * cur;

@@ -1233,11 +1236,20 @@ ggml_tensor * llm_graph_context::build_attn_mha(

cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
} else {
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

// note: this op tends to require high floating point range
// while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
// for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
ggml_tensor * kq = nullptr;
if (ggml_is_contiguous(k) && ggml_is_contiguous(q) && n_head_kv == 1) {
k = ggml_reshape_2d(ctx0, k, n_embd, n_kv);
q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
kq = ggml_mul_mat(ctx0, k, q);
// note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
} else {
kq = ggml_mul_mat(ctx0, k, q);
// note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
}

if (arch == LLM_ARCH_GROK) {
// need to do the following:
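The key idea in the llama-graph.cpp hunk above: when the KV tensors have a single head (n_head_kv == 1, i.e. MQA, which is also what MLA with the absorption optimization reduces to), the K·Q scores for all query heads can be computed with one 2D matrix multiply by fusing the token and head dimensions of Q, instead of a batched 3D multiply. A minimal standalone sketch of that equivalence in plain C++ (not ggml; the toy sizes and buffer names are invented for illustration):

// Minimal sketch (not ggml): for MQA (one shared K head) the batched K*Q score
// computation equals one flat 2D matmul where the query's token and head
// dimensions are fused into a single dimension of size n_tokens*n_head,
// mirroring the new build_attn_mha fast path.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int n_embd = 4, n_kv = 3, n_tokens = 2, n_head = 5; // toy sizes

    // k: [n_embd, n_kv] (single KV head), q: [n_embd, n_tokens, n_head], dim 0 innermost
    std::vector<float> k(n_embd * n_kv), q(n_embd * n_tokens * n_head);
    for (size_t i = 0; i < k.size(); ++i) k[i] = 0.01f * i;
    for (size_t i = 0; i < q.size(); ++i) q[i] = 0.02f * i - 0.3f;

    // batched path: one [n_kv, n_tokens] score block per head
    std::vector<float> kq_batched(n_kv * n_tokens * n_head, 0.0f);
    for (int h = 0; h < n_head; ++h)
        for (int t = 0; t < n_tokens; ++t)
            for (int kv = 0; kv < n_kv; ++kv)
                for (int e = 0; e < n_embd; ++e)
                    kq_batched[(h * n_tokens + t) * n_kv + kv] +=
                        k[kv * n_embd + e] * q[(h * n_tokens + t) * n_embd + e];

    // fused path: view q as [n_embd, n_tokens*n_head], do one 2D matmul,
    // then view the [n_kv, n_tokens*n_head] result as [n_kv, n_tokens, n_head] again
    std::vector<float> kq_flat(n_kv * n_tokens * n_head, 0.0f);
    for (int c = 0; c < n_tokens * n_head; ++c)   // fused token*head column index
        for (int kv = 0; kv < n_kv; ++kv)
            for (int e = 0; e < n_embd; ++e)
                kq_flat[c * n_kv + kv] += k[kv * n_embd + e] * q[c * n_embd + e];

    // same memory layout, element by element
    for (size_t i = 0; i < kq_flat.size(); ++i)
        assert(kq_batched[i] == kq_flat[i]);

    printf("batched and fused KQ match for all %zu elements\n", kq_flat.size());
    return 0;
}

Both loops perform the same multiply-adds in the same order, so the results match exactly; in ggml the fused variant additionally requires K and Q to be contiguous so that the 2D reshapes are plain views, which is why the fast path is guarded by the ggml_is_contiguous checks in the diff.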
17 changes: 12 additions & 5 deletions src/llama-model.cpp
@@ -10143,6 +10143,10 @@ struct llm_build_deepseek2 : public llm_graph_context {
cb(kv_cmpr, "kv_cmpr", il);

if (is_mla) {
// {n_embd_head_qk_rope, n_tokens, n_head}
q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
cb(q_pe, "q_pe_perm", il);

// {n_embd_head_qk_nope, n_tokens, n_head}
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
cb(q_nope, "q_nope_perm", il);
@@ -10151,15 +10155,15 @@
ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
cb(q_nope_absorbed, "q_nope_absorbed", il);

// {kv_lora_rank, n_head, n_tokens}
q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
cb(q_nope_absorbed, "q_nope_absorbed_perm", il);

// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
// {n_embd_head_qk_rope + kv_lora_rank, n_tokens, n_head}
// note: rope must go first for in-place context shifting in build_rope_shift()
ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
cb(Qcur, "Qcur", il);

// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
cb(Qcur, "Qcur_perm", il);

kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
cb(kv_cmpr, "kv_cmpr_reshape", il);

@@ -10171,6 +10175,9 @@
ggml_tensor * Vcur = kv_cmpr;
cb(Vcur, "Vcur", il);

Vcur = ggml_cont(ctx0, Vcur);
cb(Vcur, "Vcur_cont", il);

// note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
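On the llama-model.cpp side, the MLA branch now keeps q_pe and q_nope_absorbed in {dim, n_tokens, n_head} order, concatenates them, and applies a single ggml_permute to the concatenated Qcur (plus a ggml_cont on Vcur). Since ggml_permute only swaps sizes and strides without moving data, a second permute with the same (0, 2, 1, 3) axis swap further downstream lands back on a contiguous view, which is what the new contiguity check in build_attn_mha wants to see. A small stride-bookkeeping sketch of that property in plain C++ (not ggml; the View struct, helper names, and sizes are invented for illustration, and the second permute is assumed to happen in the attention helper):

// Toy 3-D "tensor view": sizes ne[] and byte strides nb[], loosely modelled on
// ggml_tensor metadata. permute_021 swaps dims 1 and 2 without moving data,
// which is what ggml_permute(ctx, t, 0, 2, 1, 3) does for a 3-D tensor.
#include <cstddef>
#include <cstdio>
#include <utility>

struct View {
    size_t ne[3]; // number of elements per dimension (dim 0 is innermost)
    size_t nb[3]; // stride in bytes per dimension
};

View make_contiguous(size_t ne0, size_t ne1, size_t ne2, size_t elem_size) {
    return View{{ne0, ne1, ne2},
                {elem_size, elem_size * ne0, elem_size * ne0 * ne1}};
}

View permute_021(View v) {          // {d, a, b} -> {d, b, a}, data untouched
    std::swap(v.ne[1], v.ne[2]);
    std::swap(v.nb[1], v.nb[2]);
    return v;
}

bool is_contiguous(const View & v, size_t elem_size) {
    return v.nb[0] == elem_size &&
           v.nb[1] == v.nb[0] * v.ne[0] &&
           v.nb[2] == v.nb[1] * v.ne[1];
}

int main() {
    const size_t elem = sizeof(float);
    // stand-in for Qcur after ggml_concat: {rope + lora dims, n_tokens, n_head}
    View qcur = make_contiguous(64 + 512, 8, 16, elem);

    View q_perm = permute_021(qcur);    // the single permute done in the model code
    View q_back = permute_021(q_perm);  // assumed matching permute in the caller

    printf("after one permute : contiguous = %d\n", (int) is_contiguous(q_perm, elem));
    printf("after two permutes: contiguous = %d\n", (int) is_contiguous(q_back, elem));
    return 0;
}

Run as-is, the sketch prints 0 after one permute and 1 after two: the single permute on Qcur (and the explicit ggml_cont on Vcur) leaves the tensors feeding the attention call in a layout that cancels out to contiguous, so the single-KV-head fast path above can actually trigger.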