@@ -90,6 +90,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_57B_A14B: return "57B.A14B";
         case LLM_TYPE_27B: return "27B";
         case LLM_TYPE_290B: return "290B";
+        case LLM_TYPE_17B_16E: return "17Bx16E";
         default: return "?B";
     }
 }
@@ -524,10 +525,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
-        case LLM_ARCH_LLAMA4:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
 
                 if (hparams.n_expert == 8) {
                     switch (hparams.n_layer) {
@@ -552,6 +551,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     }
                 }
             } break;
+        case LLM_ARCH_LLAMA4:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                hparams.f_attention_scale = 0.1;
+
+                switch (hparams.n_layer) {
+                    case 48: type = LLM_TYPE_17B_16E; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -4266,10 +4276,10 @@ struct llm_build_llama : public llm_graph_context {
 
                 if (use_rope) {
                     Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                        );
+                            ctx0, Qcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
 
                     Kcur = ggml_rope_ext(
                             ctx0, Kcur, inp_pos, rope_factors,
@@ -4278,6 +4288,10 @@ struct llm_build_llama : public llm_graph_context {
                             );
                 } else {
                     // TODO: support temperature tuning (attn_temperature_tuning)
+                    // Problem: we are missing 2 things:
+                    //   - ggml_cast from I32 to F32
+                    //   - ggml_floor
+                    // Ref implementation: https://github.com/ml-explore/mlx-lm/blob/9df43c9863c28065fecf87c9be2c5fd7e6f3864c/mlx_lm/models/llama4.py#L122-L130
                 }
 
                 cb(Qcur, "Qcur", il);
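
For context, the temperature tuning deferred by that TODO comes down to a single per-position scale applied to Qcur on the non-RoPE layers. Below is a minimal scalar sketch of the math from the referenced mlx-lm code, not the ggml graph version; the `floor_scale` constant (8192) and the helper name are assumptions here, and the in-graph implementation would still need the missing `ggml_cast` (I32 -> F32) and `ggml_floor` ops.

```cpp
#include <cmath>
#include <cstdint>

// Sketch only: per-position attention temperature scale as computed in the
// referenced mlx-lm llama4 code. attn_scale corresponds to
// hparams.f_attention_scale (0.1 above); floor_scale = 8192 is assumed from
// the reference model config, and the function name is hypothetical.
static float llama4_attn_temperature_scale(int32_t pos, float attn_scale = 0.1f, float floor_scale = 8192.0f) {
    // scale = log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1
    const float f = std::floor((float(pos) + 1.0f) / floor_scale);
    return std::log(f + 1.0f) * attn_scale + 1.0f;
}

// Usage idea (hypothetical): multiply each Q row by the scale for its token
// position before the attention call, i.e. Qcur at position pos *= scale(pos).
```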