Skip to content

Commit c0827df

Browse files
committed
fix
1 parent 892bbc6 commit c0827df

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

src/llama.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -604,17 +604,13 @@ static struct ggml_tensor * llm_build_kqv(
604604
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
605605

606606
if (n_embd_head_v < n_embd_head_k) {
607-
LLAMA_LOG_INFO("cur shape: [%ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2]);
608-
cur = ggml_reshape_3d(ctx, cur, n_head, n_tokens, n_embd_head_v_out);
609-
LLAMA_LOG_INFO("cur shape: [%ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2]);
610-
cur = ggml_cont(ctx, ggml_view_3d(ctx, cur, n_head, n_tokens, n_embd_head_v,
611-
ggml_element_size(cur) * n_head,
612-
ggml_element_size(cur) * n_head * n_tokens,
607+
cur = ggml_reshape_2d(ctx, ggml_cont(ctx, cur), n_embd_head_v_out*n_head, n_tokens);
608+
cur = ggml_cont(ctx, ggml_view_2d(ctx, ggml_cont(ctx, cur), n_embd_head_v*n_head, n_tokens,
609+
ggml_element_size(cur) * n_embd_head_v_out,
613610
0));
614-
LLAMA_LOG_INFO("cur shape: [%ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2]);
611+
} else {
612+
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
615613
}
616-
617-
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
618614
} else {
619615
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
620616
cb(kq, "kq", il);

0 commit comments

Comments
 (0)