@@ -2513,7 +2513,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
25132513 layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
25142514 layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
25152515
2516- layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
2516+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
2517+ if (!layer.ffn_post_norm) {
2518+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
2519+ }
25172520 }
25182521 } break;
25192522 case LLM_ARCH_DBRX:
@@ -6974,14 +6977,10 @@ struct llm_build_grok : public llm_graph_context {
69746977 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
69756978 }
69766979
6977- // Grok
6978- // if attn_out_norm is present then apply it before adding the input
6979- if (model.layers[il].attn_out_norm) {
6980- cur = build_norm(cur,
6981- model.layers[il].attn_out_norm, NULL,
6982- LLM_NORM_RMS, il);
6983- cb(cur, "attn_out_norm", il);
6984- }
6980+ cur = build_norm(cur,
6981+ model.layers[il].attn_out_norm, NULL,
6982+ LLM_NORM_RMS, il);
6983+ cb(cur, "attn_out_norm", il);
69856984
69866985 ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
69876986 cb(ffn_inp, "ffn_inp", il);
@@ -7006,15 +7005,10 @@ struct llm_build_grok : public llm_graph_context {
70067005 il);
70077006 cb(cur, "ffn_moe_out", il);
70087007
7009- // Grok
7010- // if layer_out_norm is present then apply it before adding the input
7011- // Idea: maybe ffn_out_norm is a better name
7012- if (model.layers[il].layer_out_norm) {
7013- cur = build_norm(cur,
7014- model.layers[il].layer_out_norm, NULL,
7015- LLM_NORM_RMS, il);
7016- cb(cur, "layer_out_norm", il);
7017- }
7008+ cur = build_norm(cur,
7009+ model.layers[il].ffn_post_norm, NULL,
7010+ LLM_NORM_RMS, il);
7011+ cb(cur, "ffn_post_norm", il);
70187012
70197013 cur = ggml_add(ctx0, cur, ffn_inp);
70207014 cb(cur, "ffn_out", il);
0 commit comments