
Commit 00fd137

fix

1 parent: ef0b5c4

1 file changed: 4 additions, 4 deletions


src/llama.cpp

Lines changed: 4 additions, 4 deletions
@@ -595,10 +595,9 @@ static struct ggml_tensor * llm_build_kqv(
             padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
             cb(padded_v, "padded_v", il);
             n_embd_head_v_out = n_embd_head_k;
-            padded_v = ggml_cont(ctx, padded_v);
         }
 
-        cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+        cur = ggml_flash_attn_ext(ctx, q, k, ggml_cont(ctx, padded_v), kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         LLAMA_LOG_INFO("kq_scale: %f\n", kq_scale);
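
For orientation (my reading of the change, not text from the commit): this branch zero-pads V's per-head size up to K's before calling ggml_flash_attn_ext, and the fix moves the ggml_cont from the padding branch into the call site itself. Since the removed line ran only when padding happened, the visible behavioral change is that the V argument is now made contiguous unconditionally, which also covers the case where padded_v is an unpadded but non-contiguous view (assuming padded_v is initialized to v outside this hunk). A minimal sketch of the fixed path, assuming the names in scope inside llm_build_kqv:

    // sketch of the fixed call path (not the committed code verbatim):
    // zero-pad V's per-head size (v->ne[1] in this layout) up to K's (k->ne[0]),
    // then make the result contiguous right where the flash-attention kernel reads it
    struct ggml_tensor * padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
    cur = ggml_flash_attn_ext(ctx, q, k, ggml_cont(ctx, padded_v), kq_mask, kq_scale,
            hparams.f_max_alibi_bias,
            hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);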
@@ -614,12 +613,13 @@ static struct ggml_tensor * llm_build_kqv(
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         if (n_embd_head_v < n_embd_head_k) {
+            cur = ggml_reshape_2d(ctx, ggml_cont(ctx, cur), n_embd_head_v_out*n_head, n_tokens);
             cur = ggml_cont(ctx, ggml_view_2d(ctx, ggml_cont(ctx, cur), n_embd_head_v*n_head, n_tokens,
                     ggml_element_size(cur) * n_embd_head_v_out,
                     0));
+        } else {
+            cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
         }
-
-        cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
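
The second hunk changes how the flash-attention output is un-flattened: on the padded path (n_embd_head_v < n_embd_head_k) it now reshapes using the padded head size n_embd_head_v_out first and then takes a strided 2-D view with a row stride of n_embd_head_v_out elements, while the unpadded path keeps the plain reshape. Below is a self-contained sketch of the same trim-by-strided-view idea in isolation; the sizes, the main() harness, and the shape printout are mine, while the tensor-shaping calls mirror the ones in the diff:

    // hedged sketch, not part of the commit: drop per-row zero padding by
    // viewing only the first n_real columns of each row while keeping the
    // padded row stride, then compacting the view with ggml_cont
    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_real   = 64;   // stand-in for n_embd_head_v
        const int64_t n_padded = 128;  // stand-in for n_embd_head_v_out
        const int64_t n_rows   = 8;    // stand-in for n_tokens

        // a padded, contiguous tensor: n_padded values per row
        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_padded, n_rows);

        // expose only n_real columns per row; nb1 stays the padded row stride,
        // and ggml_cont copies the view into a dense (n_real x n_rows) tensor
        struct ggml_tensor * trimmed = ggml_cont(ctx,
                ggml_view_2d(ctx, t, n_real, n_rows, t->nb[1], 0));

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, trimmed);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

        printf("trimmed: %lld x %lld\n",
                (long long) trimmed->ne[0], (long long) trimmed->ne[1]);

        ggml_free(ctx);
        return 0;
    }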
