@@ -1203,8 +1203,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
     const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
 
+    const auto n_embd    = q->ne[0];
     const auto n_tokens  = q->ne[1];
     const auto n_head    = q->ne[2];
+
     const auto n_kv      = k->ne[1];
     const auto n_head_kv = k->ne[2];
 
@@ -1237,7 +1239,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
     ggml_tensor * kq = nullptr;
     if (ggml_is_contiguous(k) && ggml_is_contiguous(q) && n_head_kv == 1) {
-        k = ggml_reshape_2d(ctx0, k, n_embd, n_tokens);
+        k = ggml_reshape_2d(ctx0, k, n_embd, n_kv);
         q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
         kq = ggml_mul_mat(ctx0, k, q);
         // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
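For context (not part of the diff above): in the MQA fast path, K holds n_kv cached positions while Q holds n_tokens query positions, so during incremental decoding n_kv != n_tokens and the old reshape produced a K view with the wrong number of rows. Below is a minimal standalone sketch of the shape logic only; the sizes, the scratch context, and the main() wrapper are hypothetical and not part of the PR, while the names n_embd, n_kv, n_tokens, n_head mirror the hunk above.

// Minimal sketch (illustrative sizes, not the PR's code): builds only the
// shape-carrying ggml ops to show why the 2D view of K must have n_kv rows.
#include "ggml.h"

int main(void) {
    // small scratch context; sizes are hypothetical
    struct ggml_init_params params = { /*mem_size*/ 16u*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx0 = ggml_init(params);

    const int64_t n_embd = 64, n_head = 8, n_tokens = 4, n_kv = 32;

    // MQA layout: Q has n_head heads, K has a single KV head
    struct ggml_tensor * q = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, n_head);
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_kv,     1);

    // collapse to 2D so one (non-batched) matmul serves all heads
    struct ggml_tensor * k2d = ggml_reshape_2d(ctx0, k, n_embd, n_kv);            // n_kv rows, not n_tokens
    struct ggml_tensor * q2d = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);

    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k2d, q2d);   // -> [n_kv, n_tokens*n_head]
    GGML_ASSERT(kq->ne[0] == n_kv && kq->ne[1] == n_tokens*n_head);

    // reshaping k with n_tokens instead is only valid when n_kv == n_tokens
    // (e.g. prompt processing with an empty cache), not during incremental decode

    ggml_free(ctx0);
    return 0;
}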