@@ -1235,18 +1235,17 @@ ggml_tensor * llm_graph_context::build_attn_mha(
12351235 cur = ggml_reshape_2d (ctx0, cur, n_embd_head_v*n_head, n_tokens);
12361236 } else {
12371237 // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
1238- if (ggml_is_contiguous (q) && n_head_kv == 1 ) {
1238+ if (ggml_is_contiguous (k) && ggml_is_contiguous (q) && n_head_kv == 1 ) {
1239+ k = ggml_reshape_2d (ctx0, k, n_embd, n_tokens);
12391240 q = ggml_reshape_2d (ctx0, q, n_embd, n_tokens*n_head);
1240- }
1241-
1242- ggml_tensor * kq = ggml_mul_mat (ctx0, k, q);
1243-
1244- // note: this op tends to require high floating point range
1245- // while for some models F16 is enough, for others it is not, so we default to F32 here
1246- ggml_mul_mat_set_prec (kq, GGML_PREC_F32);
1247-
1248- if (ggml_is_contiguous (q) && n_head_kv == 1 ) {
1241+ ggml_tensor * kq = ggml_mul_mat (ctx0, k, q);
1242+ // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
1243+ ggml_mul_mat_set_prec (kq, GGML_PREC_F32);
12491244 kq = ggml_reshape_3d (ctx0, kq, n_kv, n_tokens, n_head);
1245+ } else {
1246+ ggml_tensor * kq = ggml_mul_mat (ctx0, k, q);
1247+ // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
1248+ ggml_mul_mat_set_prec (kq, GGML_PREC_F32);
12501249 }
12511250
12521251 if (arch == LLM_ARCH_GROK) {
0 commit comments