
Commit 1d63edf

Check k is contiguous and reshape it
1 parent f2fcd2c commit 1d63edf

File tree: 1 file changed


src/llama-graph.cpp (9 additions & 10 deletions)
@@ -1235,18 +1235,17 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
-        if (ggml_is_contiguous(q) && n_head_kv == 1) {
+        if (ggml_is_contiguous(k) && ggml_is_contiguous(q) && n_head_kv == 1) {
+            k = ggml_reshape_2d(ctx0, k, n_embd, n_tokens);
             q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
-        }
-
-        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-
-        // note: this op tends to require high floating point range
-        // while for some models F16 is enough, for others it is not, so we default to F32 here
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
-        if (ggml_is_contiguous(q) && n_head_kv == 1) {
+            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
+            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
             kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
+        } else {
+            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
+            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
         }
 
         if (arch == LLM_ARCH_GROK) {
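
Context for the added check: ggml_reshape_2d() only accepts contiguous tensors, and k is typically a view into the KV cache that may not be contiguous, so the MQA fast path can only flatten k when both k and q are contiguous; otherwise the code falls back to the batched matrix multiply. Below is a minimal sketch of that pattern as a standalone helper, not the verbatim file contents: it reuses the names from the hunk above (ctx0, n_embd, n_tokens, n_kv, n_head, n_head_kv), mirrors the reshape arguments verbatim, and hoists kq into the enclosing scope so it can be returned.

// Sketch only: the guarded MQA fast path from this commit, written as a
// standalone helper. The shapes of k and q are whatever the caller built.
#include "ggml.h"
#include <stdint.h>

static ggml_tensor * build_kq(
        ggml_context * ctx0,
        ggml_tensor  * k,
        ggml_tensor  * q,
        int64_t n_embd, int64_t n_tokens, int64_t n_kv,
        int64_t n_head, int64_t n_head_kv) {
    ggml_tensor * kq;

    if (ggml_is_contiguous(k) && ggml_is_contiguous(q) && n_head_kv == 1) {
        // ggml_reshape_2d() requires contiguous inputs, hence the checks above.
        // Flattening the head dimension turns the batched multiply into a
        // single 2D matmul (the MQA fast path).
        k = ggml_reshape_2d(ctx0, k, n_embd, n_tokens);
        q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);

        kq = ggml_mul_mat(ctx0, k, q);
        // KQ tends to need the full F32 range (see the note in the diff)
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

        // restore the [n_kv, n_tokens, n_head] layout the rest of the graph expects
        kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
    } else {
        // general case: batched matrix multiply over the head dimension
        kq = ggml_mul_mat(ctx0, k, q);
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
    }

    return kq;
}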

0 commit comments