diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 5d0222b981058..1d7400446cc52 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1203,9 +1203,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
     const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
 
-    const auto n_tokens = q->ne[1];
-    const auto n_head   = q->ne[2];
-    const auto n_kv     = k->ne[1];
+    const auto n_embd   = q->ne[0];
+    const auto n_tokens = q->ne[1];
+    const auto n_head   = q->ne[2];
+
+    const auto n_kv      = k->ne[1];
+    const auto n_head_kv = k->ne[2];
 
     ggml_tensor * cur;
 
@@ -1233,11 +1236,20 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
-        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-
-        // note: this op tends to require high floating point range
-        //       while for some models F16 is enough, for others it is not, so we default to F32 here
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+        // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
+        ggml_tensor * kq = nullptr;
+        if (ggml_is_contiguous(k) && ggml_is_contiguous(q) && n_head_kv == 1) {
+            k = ggml_reshape_2d(ctx0, k, n_embd, n_kv);
+            q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
+            kq = ggml_mul_mat(ctx0, k, q);
+            // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
+            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+            kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
+        } else {
+            kq = ggml_mul_mat(ctx0, k, q);
+            // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
+            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+        }
 
         if (arch == LLM_ARCH_GROK) {
             // need to do the following:
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 248c61748eaa8..ed62191ea8d9d 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -10143,6 +10143,10 @@ struct llm_build_deepseek2 : public llm_graph_context {
                 cb(kv_cmpr, "kv_cmpr", il);
 
                 if (is_mla) {
+                    // {n_embd_head_qk_rope, n_tokens, n_head}
+                    q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+                    cb(q_pe, "q_pe_perm", il);
+
                     // {n_embd_head_qk_nope, n_tokens, n_head}
                     q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
                     cb(q_nope, "q_nope_perm", il);
@@ -10151,15 +10155,15 @@ struct llm_build_deepseek2 : public llm_graph_context {
                     ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
                     cb(q_nope_absorbed, "q_nope_absorbed", il);
 
-                    // {kv_lora_rank, n_head, n_tokens}
-                    q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
-                    cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
-
-                    // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+                    // {n_embd_head_qk_rope + kv_lora_rank, n_tokens, n_head}
                     // note: rope must go first for in-place context shifting in build_rope_shift()
                     ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
                     cb(Qcur, "Qcur", il);
 
+                    // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+                    Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                    cb(Qcur, "Qcur_perm", il);
+
                     kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
                     cb(kv_cmpr, "kv_cmpr_reshape", il);
 
@@ -10171,6 +10175,9 @@ struct llm_build_deepseek2 : public llm_graph_context {
                     ggml_tensor * Vcur = kv_cmpr;
                     cb(Vcur, "Vcur", il);
 
+                    Vcur = ggml_cont(ctx0, Vcur);
+                    cb(Vcur, "Vcur_cont", il);
+
                     // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
                     cur = build_attn(inp_attn, gf,
                             model.layers[il].wo, NULL,