Commit f6d8e75

clean up a bit

1 parent b19dbd0 commit f6d8e75

File tree

3 files changed, +22 -6 lines


src/llama-hparams.h

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@ struct llama_hparams {
     uint32_t interleave_moe_layer_step = 2; // TODO read from gguf
     uint32_t no_rope_layer_interval = 4; // TODO read from gguf
     uint32_t attn_temperature_tuning = 4; // TODO read from gguf
+    uint32_t floor_scale = 8192; // TODO read from gguf

     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
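
For context on the new floor_scale field: it drives Llama 4's attention temperature tuning on the layers that skip RoPE. Below is a minimal scalar sketch of that per-position scale, assuming the formula from the mlx-lm reference linked in the src/llama-model.cpp hunk further down; the helper name and standalone form are illustrative, not part of this commit.

#include <cmath>
#include <cstdint>

// Illustrative helper, not part of this commit: per-position attention
// temperature scale for Llama 4's non-RoPE layers, assuming the formula from
// the mlx-lm reference implementation linked in src/llama-model.cpp below:
//     scale(pos) = log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1
static float llama4_attn_temperature_scale(int32_t pos, uint32_t floor_scale, float attn_scale) {
    const float bucket = std::floor((float(pos) + 1.0f) / float(floor_scale));
    return std::log(bucket + 1.0f) * attn_scale + 1.0f;
}

With floor_scale = 8192 and attn_scale = 0.1 (the f_attention_scale value set in the llama-model.cpp hunk below), positions 0 through 8190 get a scale of exactly 1.0, and the scale then grows logarithmically with context length.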

src/llama-model.cpp

Lines changed: 20 additions & 6 deletions
@@ -90,6 +90,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_57B_A14B: return "57B.A14B";
         case LLM_TYPE_27B: return "27B";
         case LLM_TYPE_290B: return "290B";
+        case LLM_TYPE_17B_16E: return "17Bx16E";
         default: return "?B";
     }
 }
@@ -524,10 +525,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
-        case LLM_ARCH_LLAMA4:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);

                 if (hparams.n_expert == 8) {
                     switch (hparams.n_layer) {
@@ -552,6 +551,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     }
                 }
             } break;
+        case LLM_ARCH_LLAMA4:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                hparams.f_attention_scale = 0.1;
+
+                switch (hparams.n_layer) {
+                    case 48: type = LLM_TYPE_17B_16E; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -4266,10 +4276,10 @@ struct llm_build_llama : public llm_graph_context {

                 if (use_rope) {
                     Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                        );
+                            ctx0, Qcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );

                     Kcur = ggml_rope_ext(
                             ctx0, Kcur, inp_pos, rope_factors,
@@ -4278,6 +4288,10 @@ struct llm_build_llama : public llm_graph_context {
                             );
                 } else {
                     // TODO: support temperature tuning (attn_temperature_tuning)
+                    // Problem: we are missing 2 things:
+                    //   - ggml_cast from I32 to F32
+                    //   - ggml_floor
+                    // Ref implementation: https://github.com/ml-explore/mlx-lm/blob/9df43c9863c28065fecf87c9be2c5fd7e6f3864c/mlx_lm/models/llama4.py#L122-L130
                 }

                 cb(Qcur, "Qcur", il);
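
Regarding the temperature-tuning TODO above: until ggml gains an I32 to F32 cast and a floor op, one possible interim approach is to compute the per-token scales on the host from the I32 positions and multiply them into the query activations with existing element-wise ops. The sketch below illustrates that idea under the assumption that the formula in the referenced mlx-lm implementation applies; it is not something this commit implements, and the function name and buffer layout are hypothetical.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only, not part of this commit: host-side stand-in for the missing
// ggml ops. Computes the Llama 4 temperature scale per token from its I32
// position and applies it to a flat query buffer laid out as
// [n_tokens][n_head][head_dim], assuming the formula from the referenced
// mlx-lm implementation.
static void apply_attn_temperature_tuning(
        std::vector<float>         & q,         // n_tokens * n_head * head_dim floats
        const std::vector<int32_t> & positions, // n_tokens entries
        int n_head, int head_dim,
        uint32_t floor_scale, float attn_scale) {
    const size_t per_token = size_t(n_head) * size_t(head_dim);
    for (size_t t = 0; t < positions.size(); ++t) {
        const float bucket = std::floor((float(positions[t]) + 1.0f) / float(floor_scale));
        const float scale  = std::log(bucket + 1.0f) * attn_scale + 1.0f;
        for (size_t i = 0; i < per_token; ++i) {
            q[t*per_token + i] *= scale;
        }
    }
}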

src/llama-model.h

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ enum llm_type {
     LLM_TYPE_57B_A14B,
     LLM_TYPE_27B,
     LLM_TYPE_290B,
+    LLM_TYPE_17B_16E, // llama4 Scout
 };

 struct llama_layer_posnet {
