@@ -595,10 +595,9 @@ static struct ggml_tensor * llm_build_kqv(
                 padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
                 cb(padded_v, "padded_v", il);
                 n_embd_head_v_out = n_embd_head_k;
-                padded_v = ggml_cont(ctx, padded_v);
             }

-            cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+            cur = ggml_flash_attn_ext(ctx, q, k, ggml_cont(ctx, padded_v), kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                       hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

             LLAMA_LOG_INFO("kq_scale: %f\n", kq_scale);
@@ -614,12 +613,13 @@ static struct ggml_tensor * llm_build_kqv(
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

             if (n_embd_head_v < n_embd_head_k) {
+                cur = ggml_reshape_2d(ctx, ggml_cont(ctx, cur), n_embd_head_v_out*n_head, n_tokens);
                 cur = ggml_cont(ctx, ggml_view_2d(ctx, ggml_cont(ctx, cur), n_embd_head_v*n_head, n_tokens,
                         ggml_element_size(cur) * n_embd_head_v_out,
                         0));
+            } else {
+                cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
             }
-
-            cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
         } else {
             struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
             cb(kq, "kq", il);