
Commit 7b66649

Fix kq
1 parent: 1d63edf

1 file changed (+3, -2 lines)


src/llama-graph.cpp

Lines changed: 3 additions & 2 deletions
@@ -1235,15 +1235,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
+        ggml_tensor * kq = nullptr;
         if (ggml_is_contiguous(k) && ggml_is_contiguous(q) && n_head_kv == 1) {
             k = ggml_reshape_2d(ctx0, k, n_embd, n_tokens);
             q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
-            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            kq = ggml_mul_mat(ctx0, k, q);
             // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
             kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
         } else {
-            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            kq = ggml_mul_mat(ctx0, k, q);
             // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
         }
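
The change itself is small: kq was previously declared separately inside each branch, so the name went out of scope at that branch's closing brace and code after the if/else could not see it. Hoisting a single declaration above the branches and assigning inside them is the fix. A minimal standalone C++ sketch of the pattern (hypothetical names, not the ggml types):

#include <cstdio>

int main() {
    bool fast_path = true; // stands in for the contiguity / n_head_kv == 1 check

    // Broken pattern: a variable declared inside each branch dies at that
    // branch's closing brace, so code below the if/else cannot use it:
    //
    //     if (fast_path) { int x = 1; } else { int x = 2; }
    //     printf("%d\n", x); // error: 'x' was not declared in this scope
    //
    // Fixed pattern (what the commit does with kq): declare once before the
    // branches, assign inside them.
    int x = 0;
    if (fast_path) {
        x = 1;
    } else {
        x = 2;
    }
    printf("x = %d\n", x); // visible here on either path
    return 0;
}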
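The MQA comment in the hunk explains why this code path exists at all: with a single KV head (n_head_kv == 1), the same K matrix is shared by every query head, so Q can be reshaped from [n_embd, n_tokens, n_head] to a flat [n_embd, n_tokens*n_head] and multiplied against K in one 2D matmul instead of a per-head batched matmul; the flat result is then reshaped back to [n_kv, n_tokens, n_head], which is what the ggml_reshape_2d/ggml_reshape_3d calls around kq do. A standalone sketch of that equivalence with hypothetical toy sizes (plain loops, not the ggml API):

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative sizes only (not taken from the commit).
constexpr int n_embd = 4; // per-head embedding size
constexpr int n_kv   = 3; // cached key positions
constexpr int n_tok  = 2; // tokens in the batch
constexpr int n_head = 2; // query heads; the single shared KV head is implicit

int main() {
    // K: one [n_kv x n_embd] matrix shared by all query heads (MQA).
    // Q: n_head * n_tok query rows of length n_embd.
    std::vector<float> K(n_kv * n_embd), Q(n_head * n_tok * n_embd);
    for (size_t i = 0; i < K.size(); ++i) K[i] = 0.10f * (float) i;
    for (size_t i = 0; i < Q.size(); ++i) Q[i] = 0.05f * (float) i;

    // (a) Batched view: one [n_tok x n_embd] * K^T matmul per head.
    std::vector<float> kq_batched(n_head * n_tok * n_kv, 0.0f);
    for (int h = 0; h < n_head; ++h)
        for (int t = 0; t < n_tok; ++t)
            for (int p = 0; p < n_kv; ++p)
                for (int e = 0; e < n_embd; ++e)
                    kq_batched[(h*n_tok + t)*n_kv + p] +=
                        Q[(h*n_tok + t)*n_embd + e] * K[p*n_embd + e];

    // (b) Flattened view: treat Q as one [(n_head*n_tok) x n_embd] matrix and
    //     do a single 2D matmul -- the trick the diff applies when n_head_kv == 1.
    std::vector<float> kq_flat(n_head * n_tok * n_kv, 0.0f);
    for (int r = 0; r < n_head * n_tok; ++r)
        for (int p = 0; p < n_kv; ++p)
            for (int e = 0; e < n_embd; ++e)
                kq_flat[r*n_kv + p] += Q[r*n_embd + e] * K[p*n_embd + e];

    // Same numbers in the same memory order: reshaping kq_flat to
    // [n_kv, n_tok, n_head] recovers the batched result exactly.
    for (size_t i = 0; i < kq_flat.size(); ++i)
        assert(std::fabs(kq_flat[i] - kq_batched[i]) < 1e-6f);
    printf("flattened matmul matches batched matmul (%zu entries)\n", kq_flat.size());
    return 0;
}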
