@@ -1203,9 +1203,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
     const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
 
-    const auto n_tokens = q->ne[1];
-    const auto n_head   = q->ne[2];
-    const auto n_kv     = k->ne[1];
+    const auto n_tokens  = q->ne[1];
+    const auto n_head    = q->ne[2];
+    const auto n_kv      = k->ne[1];
+    const auto n_head_kv = k->ne[2];
 
     ggml_tensor * cur;
 
@@ -1233,12 +1234,21 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
+        // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
+        if (ggml_is_contiguous(q) && n_head_kv == 1) {
+            q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
+        }
+
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
         // note: this op tends to require high floating point range
         //       while for some models F16 is enough, for others it is not, so we default to F32 here
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
+        if (ggml_is_contiguous(q) && n_head_kv == 1) {
+            kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
+        }
+
         if (arch == LLM_ARCH_GROK) {
             // need to do the following:
             // multiply by attn_output_multiplyer of 0.08838834764831845
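The core of the change is the MQA fast path: when there is a single KV head, the batched `ggml_mul_mat` (which broadcasts `k` across `n_head`) can be replaced by one 2D matmul over `n_tokens*n_head` query rows, with `kq` reshaped back to 3D afterwards. Below is a minimal standalone sketch (not part of the diff) of that shape equivalence; it assumes a throwaway ggml context with `no_alloc`, uses illustrative dimensions, names the per-head query size `n_embd_head` (i.e. `q->ne[0]`), and only builds the graph nodes to compare shapes rather than running anything.

```cpp
#include "ggml.h"

int main() {
    // metadata-only context: tensors are created but no data is allocated
    ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true };
    ggml_context * ctx0 = ggml_init(params);

    // illustrative dimensions (not taken from any particular model)
    const int64_t n_embd_head = 128, n_tokens = 7, n_head = 32, n_kv = 256;

    // q: [n_embd_head, n_tokens, n_head], k: [n_embd_head, n_kv, 1] (MQA: one KV head)
    ggml_tensor * q = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_tokens, n_head);
    ggml_tensor * k = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_kv, 1);

    // batched path: k is broadcast over the head dimension -> kq: [n_kv, n_tokens, n_head]
    ggml_tensor * kq_batched = ggml_mul_mat(ctx0, k, q);

    // MQA path: fold the head dimension into the row dimension, do a single 2D matmul,
    // then restore the 3D view -> same [n_kv, n_tokens, n_head] result
    ggml_tensor * q2 = ggml_reshape_2d(ctx0, q, n_embd_head, n_tokens*n_head);
    ggml_tensor * kq = ggml_mul_mat(ctx0, k, q2);            // [n_kv, n_tokens*n_head]
    kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);  // [n_kv, n_tokens, n_head]

    GGML_ASSERT(ggml_are_same_shape(kq, kq_batched));

    ggml_free(ctx0);
    return 0;
}
```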