Commit 959a793

Added missing n_embd and fixed n_kv bug

Parent: 7b66649

File tree: 1 file changed (+3, -1)

src/llama-graph.cpp (3 additions, 1 deletion)
@@ -1203,8 +1203,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
     const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
 
+    const auto n_embd   = q->ne[0];
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
+
     const auto n_kv      = k->ne[1];
     const auto n_head_kv = k->ne[2];
 
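The added n_embd declaration reads the per-head embedding size off q, following the layout this function assumes: q is [n_embd, n_tokens, n_head] and k is [n_embd, n_kv, n_head_kv], where n_kv is the number of key/value positions visible to the attention (typically the KV-cache length) and in general differs from n_tokens. A minimal standalone sketch of reading those shapes, with invented sizes rather than anything taken from the commit:

    // Hypothetical sketch (invented sizes): reading the attention shapes off Q and K
    // the way the declarations above do.
    #include "ggml.h"
    #include <cassert>

    int main() {
        ggml_init_params params = { /*mem_size=*/ 1024*1024, /*mem_buffer=*/ nullptr, /*no_alloc=*/ false };
        ggml_context * ctx0 = ggml_init(params);

        // Q: [n_embd, n_tokens, n_head], K: [n_embd, n_kv, n_head_kv]
        ggml_tensor * q = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 64, 4, 8);
        ggml_tensor * k = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 64, 40, 1);

        const auto n_embd   = q->ne[0];  // per-head embedding size (64)
        const auto n_tokens = q->ne[1];  // tokens in the current batch (4)
        const auto n_head   = q->ne[2];  // number of query heads (8)
        const auto n_kv     = k->ne[1];  // key/value positions (40) -- not n_tokens
        assert(n_embd == 64 && n_tokens == 4 && n_head == 8 && n_kv == 40);

        ggml_free(ctx0);
        return 0;
    }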

@@ -1237,7 +1239,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
     ggml_tensor * kq = nullptr;
     if (ggml_is_contiguous(k) && ggml_is_contiguous(q) && n_head_kv == 1) {
-        k = ggml_reshape_2d(ctx0, k, n_embd, n_tokens);
+        k = ggml_reshape_2d(ctx0, k, n_embd, n_kv);
         q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
         kq = ggml_mul_mat(ctx0, k, q);
         // note: this op tends to require high floating point range while for some models F16 is enough, for others it is not, so we default to F32 here
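The second hunk fixes the MQA fast path: with a single KV head, k and q are flattened to 2D and multiplied with one non-batched ggml_mul_mat, and k's second dimension is n_kv, not n_tokens, so the old reshape was only correct when the two happened to coincide. A hedged standalone sketch of that path (invented sizes, not the actual build_attn_mha code):

    // Hypothetical MQA fast-path sketch illustrating the n_kv fix. Example sizes only.
    #include "ggml.h"
    #include <cstdint>
    #include <cstdio>

    int main() {
        ggml_init_params params = { /*mem_size=*/ 16*1024*1024, /*mem_buffer=*/ nullptr, /*no_alloc=*/ false };
        ggml_context * ctx0 = ggml_init(params);

        const int64_t n_embd = 64, n_tokens = 4, n_head = 8, n_kv = 40;

        // Q: [n_embd, n_tokens, n_head], K: [n_embd, n_kv, 1] (MQA: one shared KV head)
        ggml_tensor * q = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, n_head);
        ggml_tensor * k = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_kv,     1);

        // Flatten to 2D and use a single (non-batched) matrix multiply.
        // Reshaping k with n_tokens instead of n_kv only works when n_kv == n_tokens;
        // here (40 vs 4) it would not even preserve the element count.
        ggml_tensor * k2 = ggml_reshape_2d(ctx0, k, n_embd, n_kv);
        ggml_tensor * q2 = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
        ggml_tensor * kq = ggml_mul_mat(ctx0, k2, q2);  // scores: [n_kv, n_tokens*n_head]

        printf("kq: [%lld, %lld]\n", (long long) kq->ne[0], (long long) kq->ne[1]);

        ggml_free(ctx0);
        return 0;
    }

With these shapes the attention scores come out as [n_kv, n_tokens*n_head]; passing n_tokens as k's second dimension would break whenever the cache holds more positions than the current batch, which is the n_kv bug named in the commit message.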
