Skip to content

Commit 892bbc6

Browse files
committed
fix
1 parent 848cade commit 892bbc6

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/llama.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -605,11 +605,11 @@ static struct ggml_tensor * llm_build_kqv(
605605

606606
if (n_embd_head_v < n_embd_head_k) {
607607
LLAMA_LOG_INFO("cur shape: [%ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2]);
608-
cur = ggml_reshape_3d(ctx, cur, n_head, n_embd_head_v_out, n_tokens);
608+
cur = ggml_reshape_3d(ctx, cur, n_head, n_tokens, n_embd_head_v_out);
609609
LLAMA_LOG_INFO("cur shape: [%ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2]);
610-
cur = ggml_cont(ctx, ggml_view_3d(ctx, cur, n_head, n_embd_head_v, n_tokens,
610+
cur = ggml_cont(ctx, ggml_view_3d(ctx, cur, n_head, n_tokens, n_embd_head_v,
611611
ggml_element_size(cur) * n_head,
612-
ggml_element_size(cur) * n_embd_head_v_out * n_head,
612+
ggml_element_size(cur) * n_head * n_tokens,
613613
0));
614614
LLAMA_LOG_INFO("cur shape: [%ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2]);
615615
}

0 commit comments

Comments
 (0)