Commit f2fcd2c

Add back MQA 2D x 2D optimisation
1 parent 0518461 commit f2fcd2c

File tree

1 file changed: +13 -3 lines changed


src/llama-graph.cpp

Lines changed: 13 additions & 3 deletions
@@ -1203,9 +1203,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
     const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];

-    const auto n_tokens = q->ne[1];
-    const auto n_head   = q->ne[2];
-    const auto n_kv     = k->ne[1];
+    const auto n_tokens  = q->ne[1];
+    const auto n_head    = q->ne[2];
+    const auto n_kv      = k->ne[1];
+    const auto n_head_kv = k->ne[2];

     ggml_tensor * cur;

@@ -1233,12 +1234,21 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
+        // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
+        if (ggml_is_contiguous(q) && n_head_kv == 1) {
+            q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
+        }
+
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

         // note: this op tends to require high floating point range
         // while for some models F16 is enough, for others it is not, so we default to F32 here
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

+        if (ggml_is_contiguous(q) && n_head_kv == 1) {
+            kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
+        }
+
         if (arch == LLM_ARCH_GROK) {
             // need to do the following:
             // multiply by attn_output_multiplyer of 0.08838834764831845
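Why this works: the diff's comment notes that for MQA (GQA with a single KV head) no batched matrix multiply is needed. Since the one K head is shared by every query head, each score is just a dot product between a K row and a Q row, independent of which head the Q row belongs to. Folding Q's head dimension into its token dimension therefore changes nothing about the arithmetic, and the reshape back to 3D afterwards is a no-op on the data. The sketch below is not part of the commit; it is a standalone, plain C++ illustration (dense loops standing in for ggml_mul_mat, with the names n_embd, n_head, n_tokens, n_kv borrowed from the diff) showing that the 2D x 2D path produces the same KQ scores as the per-head batched path when n_head_kv == 1.

// Standalone sketch (assumption: not the commit's code) comparing the batched
// and folded 2D x 2D attention-score computations for a single shared K head.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_embd = 4, n_head = 3, n_tokens = 2, n_kv = 5;

    // K: [n_kv][n_embd] (one shared KV head), Q: [n_head][n_tokens][n_embd]
    std::vector<float> K(n_kv * n_embd), Q(n_head * n_tokens * n_embd);
    for (size_t i = 0; i < K.size(); ++i) K[i] = std::sin(0.1f * (float) i);
    for (size_t i = 0; i < Q.size(); ++i) Q[i] = std::cos(0.2f * (float) i);

    // Batched path: one [n_tokens x n_kv] matmul per head against the shared K.
    std::vector<float> kq_batched(n_head * n_tokens * n_kv, 0.0f);
    for (int h = 0; h < n_head; ++h)
        for (int t = 0; t < n_tokens; ++t)
            for (int j = 0; j < n_kv; ++j)
                for (int e = 0; e < n_embd; ++e)
                    kq_batched[(h*n_tokens + t)*n_kv + j] +=
                        K[j*n_embd + e] * Q[(h*n_tokens + t)*n_embd + e];

    // Folded path: treat Q as a single [n_tokens*n_head x n_embd] matrix
    // (the ggml_reshape_2d in the diff), do one 2D x 2D matmul, then view the
    // result as [n_head][n_tokens][n_kv] again -- a reshape that touches no data.
    std::vector<float> kq_2d(n_head * n_tokens * n_kv, 0.0f);
    for (int r = 0; r < n_tokens*n_head; ++r)
        for (int j = 0; j < n_kv; ++j)
            for (int e = 0; e < n_embd; ++e)
                kq_2d[r*n_kv + j] += K[j*n_embd + e] * Q[r*n_embd + e];

    // The two paths compute identical scores.
    for (size_t i = 0; i < kq_2d.size(); ++i)
        assert(std::fabs(kq_2d[i] - kq_batched[i]) < 1e-5f);

    std::printf("2D x 2D path matches the batched path when n_head_kv == 1\n");
    return 0;
}

The ggml_is_contiguous(q) check in the diff matters because the folding trick relies on the head and token dimensions of Q being laid out back-to-back in memory, so the 2D view is free only for a contiguous tensor.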
