Commit f8f5be1

fix
1 parent 87f1435 commit f8f5be1

File tree

1 file changed: +3 -13 lines changed

src/llama.cpp

Lines changed: 3 additions & 13 deletions
```diff
@@ -589,21 +589,11 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        struct ggml_tensor * padded_v = v;
-        if (n_embd_head_v < n_embd_head_k) {
-            padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
-            cb(padded_v, "padded_v", il);
-        }
-
-        cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
-        if (n_embd_head_v < n_embd_head_k) {
-            cur = ggml_view_1d(ctx, ggml_cont(ctx, cur), n_embd_head_k*n_head, n_tokens);
-        }
-
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
```
```diff
@@ -9577,8 +9567,8 @@ struct llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.n_embd_head_k < model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k >= n_embd_head_v - forcing off\n", __func__);
+    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
     }
 
```

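Taken together, the two hunks trade the V-padding workaround for a hard requirement that the K and V head sizes match: `llm_build_kqv` now hands `v` to `ggml_flash_attn_ext` unpadded, and `llama_init_from_model` forces `flash_attn` off whenever `n_embd_head_k != n_embd_head_v`. Below is a minimal, self-contained sketch of that guard; the `sketch_*` names and the head sizes in `main` are illustrative stand-ins, not llama.cpp types or values.

```cpp
// Sketch of the guard added in the second hunk, assuming only the fields the
// diff touches. Not llama.cpp code: the types and example head sizes are made
// up for illustration.
#include <cstdio>

struct sketch_hparams { int n_embd_head_k; int n_embd_head_v; };
struct sketch_cparams { bool flash_attn; };

// Mirrors the check in llama_init_from_model: flash attention stays enabled
// only when the K and V head sizes are equal.
static void sketch_validate_flash_attn(const sketch_hparams & hparams, sketch_cparams & params) {
    if (params.flash_attn && hparams.n_embd_head_k != hparams.n_embd_head_v) {
        std::fprintf(stderr, "flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n");
        params.flash_attn = false;
    }
}

int main() {
    sketch_hparams hp = { /*n_embd_head_k=*/128, /*n_embd_head_v=*/64 }; // mismatched on purpose
    sketch_cparams cp = { /*flash_attn=*/true };

    sketch_validate_flash_attn(hp, cp);
    std::printf("flash_attn enabled: %s\n", cp.flash_attn ? "yes" : "no"); // prints "no"
    return 0;
}
```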