@@ -604,17 +604,13 @@ static struct ggml_tensor * llm_build_kqv(
604604 ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
605605
606606 if (n_embd_head_v < n_embd_head_k) {
607- LLAMA_LOG_INFO (" cur shape: [%ld, %ld, %ld]\n " , cur->ne [0 ], cur->ne [1 ], cur->ne [2 ]);
608- cur = ggml_reshape_3d (ctx, cur, n_head, n_tokens, n_embd_head_v_out);
609- LLAMA_LOG_INFO (" cur shape: [%ld, %ld, %ld]\n " , cur->ne [0 ], cur->ne [1 ], cur->ne [2 ]);
610- cur = ggml_cont (ctx, ggml_view_3d (ctx, cur, n_head, n_tokens, n_embd_head_v,
611- ggml_element_size (cur) * n_head,
612- ggml_element_size (cur) * n_head * n_tokens,
607+ cur = ggml_reshape_2d (ctx, ggml_cont (ctx, cur), n_embd_head_v_out*n_head, n_tokens);
608+ cur = ggml_cont (ctx, ggml_view_2d (ctx, ggml_cont (ctx, cur), n_embd_head_v*n_head, n_tokens,
609+ ggml_element_size (cur) * n_embd_head_v_out,
613610 0 ));
614- LLAMA_LOG_INFO (" cur shape: [%ld, %ld, %ld]\n " , cur->ne [0 ], cur->ne [1 ], cur->ne [2 ]);
611+ } else {
612+ cur = ggml_reshape_2d (ctx, cur, n_embd_head_v*n_head, n_tokens);
615613 }
616-
617- cur = ggml_reshape_2d (ctx, cur, n_embd_head_v*n_head, n_tokens);
618614 } else {
619615 struct ggml_tensor * kq = ggml_mul_mat (ctx, k, q);
620616 cb (kq, " kq" , il);
0 commit comments