@@ -524,8 +524,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);

                 if (hparams.n_expert == 8) {
                     switch (hparams.n_layer) {
@@ -1631,6 +1633,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     const auto tn = LLM_TN(arch);
     switch (arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
         case LLM_ARCH_REFACT:
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_GRANITE:
@@ -1648,6 +1651,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }

                 for (int i = 0; i < n_layer; ++i) {
+                    bool is_moe_layer = (i + 1) % hparams.interleave_moe_layer_step == 0;
+
                     auto & layer = layers[i];

                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
@@ -1673,7 +1678,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
                     }

-                    if (n_expert == 0) {
+                    int n_ff_exp = hparams.n_ff_exp;
+                    if (n_expert == 0 || !is_moe_layer) {
                         layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
@@ -1684,9 +1690,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
                     } else {
                         layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+
+                        // Shared expert branch (only used by llama 4 for now)
+                        if (arch == LLM_ARCH_LLAMA4) {
+                            const int64_t n_ff_shexp = n_ff_exp;
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), { n_embd, n_ff_shexp}, 0);
+                        }
                     }
                 }
             } break;
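
As an aside on the `is_moe_layer` check added above: here is a minimal standalone sketch (toy layer count and step values, not read from any model) of which layers end up with expert tensors for a given `interleave_moe_layer_step`. It only mirrors the `(i + 1) % step == 0` rule, nothing else.

```cpp
// Illustrative only: which layers get expert tensors for a given
// interleave step, mirroring the `(i + 1) % step == 0` check in the patch.
#include <cstdio>

static bool is_moe_layer(int i, int interleave_moe_layer_step) {
    return (i + 1) % interleave_moe_layer_step == 0;
}

int main() {
    const int n_layer = 8;            // toy value; real models have far more layers
    const int steps[] = {1, 2};       // assumed example values
    for (int step : steps) {
        printf("step=%d:", step);
        for (int i = 0; i < n_layer; ++i) {
            printf(" %s", is_moe_layer(i, step) ? "moe" : "dense");
        }
        printf("\n");
    }
    return 0;
}
```

With a step of 1 every layer is MoE; with a step of 2 the layers alternate dense/MoE, which is why the dense `ffn_gate/down/up` path is also taken when `!is_moe_layer`.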
@@ -4209,6 +4223,10 @@ struct llm_build_llama : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;

+            bool use_rope = arch == LLM_ARCH_LLAMA4
+                ? (il + 1) % hparams.no_rope_layer_interval != 0
+                : true;
+
             // norm
             cur = build_norm(inpL,
                     model.layers[il].attn_norm, NULL,
@@ -4246,25 +4264,39 @@ struct llm_build_llama : public llm_graph_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

-                Qcur = ggml_rope_ext(
+                if (use_rope) {
+                    Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );

-                Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                        );
+                    Kcur = ggml_rope_ext(
+                            ctx0, Kcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+                } else {
+                    // TODO: support temperature tuning (attn_temperature_tuning)
+                }

                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);

+                if (arch == LLM_ARCH_LLAMA4 && use_rope) {
+                    // Llama4TextL2Norm
+                    // TODO @ngxson : the 128E model does not use qk_norm
+                    Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6);
+                    Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6);
+                    cb(Qcur, "Qcur_normed", il);
+                    cb(Kcur, "Kcur_normed", il);
+                }
+
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
             }

             if (il == n_layer - 1) {
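
The `use_rope` flag follows the same 1-based pattern as the MoE interleaving: every `no_rope_layer_interval`-th layer becomes a NoPE layer that skips `ggml_rope_ext` and, in the Llama 4 path, the QK RMS norm as well. A hedged sketch of that selection, with an assumed interval of 4 and a toy layer count (the real value comes from the converted GGUF metadata):

```cpp
// Illustrative only: which layers apply RoPE under the rule above.
#include <cstdio>

int main() {
    const int n_layer = 12;                 // toy value
    const int no_rope_layer_interval = 4;   // assumed for illustration, not read from a model
    for (int il = 0; il < n_layer; ++il) {
        const bool use_rope = (il + 1) % no_rope_layer_interval != 0;
        printf("layer %2d: %s\n", il, use_rope ? "rope" : "nope");
    }
    return 0;
}
```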
@@ -4282,7 +4314,7 @@ struct llm_build_llama : public llm_graph_context {
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);

-            // feed-forward network
+            // feed-forward network (non-MoE)
             if (model.layers[il].ffn_gate_inp == nullptr) {

                 cur = build_norm(ffn_inp,
@@ -4297,6 +4329,35 @@ struct llm_build_llama : public llm_graph_context {
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, il);
                 cb(cur, "ffn_out", il);
+
+            } else if (arch == LLM_ARCH_LLAMA4) {
+                // llama4 MoE
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, false,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
+                        il);
+
+                // Shared experts
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_moe_shexp", il);
+
             } else {
                 // MoE branch
                 cur = build_norm(ffn_inp,
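
On `LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID`: as I read it, the router scores in this path go through a per-expert sigmoid rather than a softmax over experts, and the selected experts' outputs are weighted by those scores directly (the `false` norm flag means no renormalization of the top-k weights). A rough standalone sketch of that routing math, with made-up logits and assumed top-1 routing, kept separate from the actual ggml graph code:

```cpp
// Rough illustration of sigmoid-gated top-k routing (not the ggml implementation):
// score each expert with a sigmoid, keep the k best, and use those scores
// directly as the mixture weights (no softmax renormalization).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> router_logits = {0.3f, -1.2f, 2.0f, 0.1f}; // hypothetical, n_expert = 4
    const int n_expert_used = 1;                                        // assumed top-1 routing for illustration

    // per-expert sigmoid scores
    std::vector<float> probs(router_logits.size());
    for (size_t e = 0; e < router_logits.size(); ++e) {
        probs[e] = 1.0f / (1.0f + std::exp(-router_logits[e]));
    }

    // pick the k highest-scoring experts
    std::vector<int> idx(probs.size());
    for (size_t e = 0; e < idx.size(); ++e) idx[e] = (int) e;
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });

    for (int k = 0; k < n_expert_used; ++k) {
        printf("use expert %d with weight %.3f\n", idx[k], probs[idx[k]]);
    }
    return 0;
}
```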
@@ -12091,6 +12152,7 @@ llm_graph_result_ptr llama_model::build_graph(

     switch (arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
@@ -12440,6 +12502,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {

         // use what we call a normal RoPE, operating on pairs of consecutive head values
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
         case LLM_ARCH_DECI:
         case LLM_ARCH_BAICHUAN:
         case LLM_ARCH_STARCODER: