Skip to content

Commit c6067bf

Browse files
authored
Added test using n_tokens < 32 to start with
1 parent e54b394 commit c6067bf

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

src/llama-graph.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,22 +1091,25 @@ ggml_tensor * llm_graph_context::build_attn_mha(
10911091
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
10921092

10931093
if (v_mla) {
1094-
#if 0
1095-
// v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1096-
// However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
1097-
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1098-
cur = ggml_mul_mat(ctx0, v_mla, cur);
1099-
#else
1100-
// It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
1101-
// The permutations are noops and only change how the tensor data is interpreted.
1102-
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1103-
cur = ggml_mul_mat(ctx0, v_mla, cur);
1104-
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1105-
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1106-
#endif
1094+
// To "decompress" from MQA back to MHA, v_mla can either be applied as:
1095+
// 1. A matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1096+
// - The code is optimized for dimensions 0 and 1 being large, so this is inefficient.
1097+
// 2. A matrix-matrix multiplication with n_tokens in dimension 1.
1098+
// - The added cost of the cont means that (1) is still more efficient for small batches.
1099+
if (n_tokens < 32) {
1100+
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1101+
cur = ggml_mul_mat(ctx0, v_mla, cur);
1102+
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1103+
} else {
1104+
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1105+
cur = ggml_mul_mat(ctx0, v_mla, cur);
1106+
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1107+
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1108+
}
1109+
} else {
1110+
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
11071111
}
11081112

1109-
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
11101113
} else {
11111114
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
11121115

0 commit comments

Comments
 (0)