@@ -8025,6 +8025,8 @@ struct llm_build_bert : public llm_graph_context {
80258025 }
80268026
80278027 if (model.layers[il].attn_q_norm) {
8028+ Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
8029+
80288030 Qcur = build_norm(Qcur,
80298031 model.layers[il].attn_q_norm,
80308032 model.layers[il].attn_q_norm_b,
@@ -8034,6 +8036,8 @@ struct llm_build_bert : public llm_graph_context {
80348036 }
80358037
80368038 if (model.layers[il].attn_k_norm) {
8039+ Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
8040+
80378041 Kcur = build_norm(Kcur,
80388042 model.layers[il].attn_k_norm,
80398043 model.layers[il].attn_k_norm_b,
@@ -8416,6 +8420,9 @@ struct llm_build_mpt : public llm_graph_context {
84168420
84178421 // Q/K Layernorm
84188422 if (model.layers[il].attn_q_norm) {
8423+ Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
8424+ Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
8425+
84198426 Qcur = build_norm(Qcur,
84208427 model.layers[il].attn_q_norm,
84218428 model.layers[il].attn_q_norm_b,
0 commit comments