Skip to content

Commit 9e84fbf

Browse files
CISC authored and pwilkin committed
llama : fix shapes for bert/mpt q/k norm (ggml-org#16409)
1 parent 5acd3e8 commit 9e84fbf

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/llama-model.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8025,6 +8025,8 @@ struct llm_build_bert : public llm_graph_context {
80258025
}
80268026

80278027
if (model.layers[il].attn_q_norm) {
8028+
Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
8029+
80288030
Qcur = build_norm(Qcur,
80298031
model.layers[il].attn_q_norm,
80308032
model.layers[il].attn_q_norm_b,
@@ -8034,6 +8036,8 @@ struct llm_build_bert : public llm_graph_context {
80348036
}
80358037

80368038
if (model.layers[il].attn_k_norm) {
8039+
Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
8040+
80378041
Kcur = build_norm(Kcur,
80388042
model.layers[il].attn_k_norm,
80398043
model.layers[il].attn_k_norm_b,
@@ -8416,6 +8420,9 @@ struct llm_build_mpt : public llm_graph_context {
84168420

84178421
// Q/K Layernorm
84188422
if (model.layers[il].attn_q_norm) {
8423+
Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
8424+
Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
8425+
84198426
Qcur = build_norm(Qcur,
84208427
model.layers[il].attn_q_norm,
84218428
model.layers[il].attn_q_norm_b,

0 commit comments

Comments (0)