@@ -589,21 +589,11 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        struct ggml_tensor * padded_v = v;
-        if (n_embd_head_v < n_embd_head_k) {
-            padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
-            cb(padded_v, "padded_v", il);
-        }
-
-        cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
-        if (n_embd_head_v < n_embd_head_k) {
-            cur = ggml_view_1d(ctx, ggml_cont(ctx, cur), n_embd_head_k*n_head, n_tokens);
-        }
-
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
@@ -9577,8 +9567,8 @@ struct llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.n_embd_head_k < model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k >= n_embd_head_v - forcing off\n", __func__);
+    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
     }
 
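For context, the net effect of the second hunk is that flash attention now stays enabled only when the K and V head dimensions match exactly, rather than merely requiring K to be at least as wide as V (which the first hunk's removed V-padding path used to accommodate). A minimal standalone sketch of that tightened check, using hypothetical helper names rather than any llama.cpp API:

```c
#include <stdbool.h>
#include <stdio.h>

// Sketch of the tightened guard: flash attention is kept on only for equal head sizes.
static bool flash_attn_supported(int n_embd_head_k, int n_embd_head_v) {
    return n_embd_head_k == n_embd_head_v;
}

int main(void) {
    // A model whose K heads are wider than its V heads (e.g. 192 vs 128) no longer qualifies.
    printf("192/128 -> %s\n", flash_attn_supported(192, 128) ? "on" : "off");
    printf("128/128 -> %s\n", flash_attn_supported(128, 128) ? "on" : "off");
    return 0;
}
```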