@@ -1091,22 +1091,25 @@ ggml_tensor * llm_graph_context::build_attn_mha(
10911091 ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
10921092
10931093 if (v_mla) {
1094- #if 0
1095- // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1096- // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
1097- cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1098- cur = ggml_mul_mat(ctx0, v_mla, cur);
1099- #else
1100- // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
1101- // The permutations are noops and only change how the tensor data is interpreted.
1102- cur = ggml_permute (ctx0, cur, 0 , 2 , 1 , 3 );
1103- cur = ggml_mul_mat (ctx0, v_mla, cur);
1104- cur = ggml_permute (ctx0, cur, 0 , 2 , 1 , 3 );
1105- cur = ggml_cont (ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1106- #endif
1094+ // To "decompress" from MQA back to MHA, v_mla can be either be applied as:
1095+ // 1. A matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1096+ // - The code is optimized for dimensions 0 and 1 being large, so this is ineffient.
1097+ // 2. A matrix-matrix multiplication with n_tokens in dimension 1.
1098+ // - The added cost of the cont means that (1) is still more effeicent for small batches.
1099+ if (n_tokens < 32 ) {
1100+ cur = ggml_reshape_4d (ctx0, cur, v_mla->ne [0 ], 1 , n_head, n_tokens);
1101+ cur = ggml_mul_mat (ctx0, v_mla, cur);
1102+ cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens);
1103+ } else {
1104+ cur = ggml_permute (ctx0, cur, 0 , 2 , 1 , 3 );
1105+ cur = ggml_mul_mat (ctx0, v_mla, cur);
1106+ cur = ggml_permute (ctx0, cur, 0 , 2 , 1 , 3 );
1107+ cur = ggml_cont_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens);
1108+ }
1109+ } else {
1110+ cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens);
11071111 }
11081112
1109- cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens);
11101113 } else {
11111114 ggml_tensor * kq = ggml_mul_mat (ctx0, k, q);
11121115
0 commit comments