
Commit b19dbd0

initial support, no chat template
1 parent 79ebef8 commit b19dbd0

File tree

7 files changed: +121 -12 lines changed

include/llama.h
src/llama-arch.cpp
src/llama-arch.h
src/llama-graph.cpp
src/llama-hparams.h
src/llama-model.cpp
src/llama-vocab.cpp

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_SUPERBPE   = 30,
         LLAMA_VOCAB_PRE_TYPE_TRILLION   = 31,
         LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
+        LLAMA_VOCAB_PRE_TYPE_LLAMA4     = 33,
     };

     enum llama_rope_type {

src/llama-arch.cpp

Lines changed: 30 additions & 0 deletions
@@ -6,6 +6,7 @@

 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA,  "llama"  },
+    { LLM_ARCH_LLAMA4, "llama4" },
     { LLM_ARCH_DECI,   "deci"   },
     { LLM_ARCH_FALCON, "falcon" },
     { LLM_ARCH_GROK,   "grok"   },

@@ -233,6 +234,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_LLAMA4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,  "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,  "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_DECI,
         {
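
The table above only registers printf-style name templates for the new LLM_ARCH_LLAMA4 entry. A minimal standalone sketch of the naming scheme those templates imply (this is not the LLM_TN helper from llama-arch; the tensor_name function and the ".weight"/".bias" suffix handling here are assumptions for illustration):

// Standalone sketch: a per-layer GGUF tensor name is assumed to be the
// template with the block index filled in, plus a suffix such as "weight".
#include <cstdio>
#include <string>

static std::string tensor_name(const char * templ, int il, const char * suffix) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), templ, il); // fill in the %d block index
    return std::string(buf) + "." + suffix;
}

int main() {
    // e.g. "blk.3.ffn_gate_exps.weight" and "blk.0.attn_q.weight"
    std::printf("%s\n", tensor_name("blk.%d.ffn_gate_exps", 3, "weight").c_str());
    std::printf("%s\n", tensor_name("blk.%d.attn_q",        0, "weight").c_str());
}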

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@

 enum llm_arch {
     LLM_ARCH_LLAMA,
+    LLM_ARCH_LLAMA4,
     LLM_ARCH_DECI,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,

src/llama-graph.cpp

Lines changed: 8 additions & 0 deletions
@@ -841,6 +841,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(selection_probs, "ffn_moe_probs_biased", il);
     }

+    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
+    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
+    if (arch == LLM_ARCH_LLAMA4) {
+        selection_probs = logits;
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);

@@ -914,6 +920,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         moe_out = ggml_cont(ctx0, moe_out);
     }

+    cb(moe_out, "ffn_moe_out", il);
+
     return moe_out;
 }
src/llama-hparams.h

Lines changed: 5 additions & 0 deletions
@@ -112,6 +112,11 @@ struct llama_hparams {
     bool use_alibi     = false;
     bool attn_soft_cap = false;

+    // TODO @ngxson : variable names taken from python code, we can rename it later
+    uint32_t interleave_moe_layer_step = 2; // TODO read from gguf
+    uint32_t no_rope_layer_interval    = 4; // TODO read from gguf
+    uint32_t attn_temperature_tuning   = 4; // TODO read from gguf
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
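
With the defaults above, interleave_moe_layer_step = 2 makes every second block a MoE block and no_rope_layer_interval = 4 makes every fourth block skip RoPE, via the (index + 1) % interval checks used later in llama-model.cpp. A small standalone sketch of the resulting schedule (the 8-layer count is illustrative):

// Standalone sketch of how the two interval hparams are consumed downstream.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t interleave_moe_layer_step = 2; // defaults from the hunk above
    const uint32_t no_rope_layer_interval    = 4;
    const int n_layer = 8;                        // illustrative layer count

    for (int il = 0; il < n_layer; ++il) {
        const bool is_moe_layer = (il + 1) % interleave_moe_layer_step == 0;
        const bool use_rope     = (il + 1) % no_rope_layer_interval   != 0;
        std::printf("layer %d: %s, %s\n", il,
                    is_moe_layer ? "moe"  : "dense",
                    use_rope     ? "rope" : "no rope");
    }
}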

src/llama-model.cpp

Lines changed: 74 additions & 11 deletions
@@ -524,8 +524,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);

                 if (hparams.n_expert == 8) {
                     switch (hparams.n_layer) {

@@ -1631,6 +1633,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     const auto tn = LLM_TN(arch);
     switch (arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
         case LLM_ARCH_REFACT:
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_GRANITE:

@@ -1648,6 +1651,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }

                 for (int i = 0; i < n_layer; ++i) {
+                    bool is_moe_layer = (i + 1) % hparams.interleave_moe_layer_step == 0;
+
                     auto & layer = layers[i];

                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

@@ -1673,7 +1678,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
                     }

-                    if (n_expert == 0) {
+                    int n_ff_exp = hparams.n_ff_exp;
+                    if (n_expert == 0 || !is_moe_layer) {
                         layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);

@@ -1684,9 +1690,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
                     } else {
                         layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+
+                        // Shared expert branch (only used by llama 4 for now)
+                        if (arch == LLM_ARCH_LLAMA4) {
+                            const int64_t n_ff_shexp = n_ff_exp;
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd    }, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp}, 0);
+                        }
                     }
                 }
             } break;
@@ -4209,6 +4223,10 @@ struct llm_build_llama : public llm_graph_context {
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

+            bool use_rope = arch == LLM_ARCH_LLAMA4
+                ? (il + 1) % hparams.no_rope_layer_interval != 0
+                : true;
+
            // norm
            cur = build_norm(inpL,
                    model.layers[il].attn_norm, NULL,

@@ -4246,25 +4264,39 @@ struct llm_build_llama : public llm_graph_context {
                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

-                Qcur = ggml_rope_ext(
+                if (use_rope) {
+                    Qcur = ggml_rope_ext(
                        ctx0, Qcur, inp_pos, rope_factors,
                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow
                        );

-                Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                        );
+                    Kcur = ggml_rope_ext(
+                            ctx0, Kcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+                } else {
+                    // TODO: support temperature tuning (attn_temperature_tuning)
+                }

                cb(Qcur, "Qcur", il);
                cb(Kcur, "Kcur", il);
                cb(Vcur, "Vcur", il);

+                if (arch == LLM_ARCH_LLAMA4 && use_rope) {
+                    // Llama4TextL2Norm
+                    // TODO @ngxson : the 128E model does not use qk_norm
+                    Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6);
+                    Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6);
+                    cb(Qcur, "Qcur_normed", il);
+                    cb(Kcur, "Kcur_normed", il);
+                }
+
                cur = build_attn(inp_attn, gf,
                        model.layers[il].wo, model.layers[il].bo,
                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
            }

            if (il == n_layer - 1) {
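
On RoPE layers, Llama 4 additionally RMS-normalizes Q and K with eps = 1e-6 and no learned weight (Llama4TextL2Norm); no-RoPE layers only carry a TODO for attention temperature tuning for now. A minimal standalone sketch of that unweighted RMS norm on a single head-sized vector (not the ggml_rms_norm kernel; sizes and values are illustrative):

// Standalone sketch: RMS-normalize a vector with eps = 1e-6 and no
// per-channel scale, mirroring the QK norm added in the hunk above.
#include <cmath>
#include <cstdio>
#include <vector>

static void rms_norm(std::vector<float> & x, float eps = 1e-6f) {
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;
    const float scale = 1.0f / std::sqrt(sum_sq / x.size() + eps);
    for (float & v : x) v *= scale; // no learned weight applied
}

int main() {
    std::vector<float> q = {0.5f, -1.0f, 2.0f, 0.25f}; // one head, toy values
    rms_norm(q);
    for (float v : q) std::printf("%.4f ", v);
    std::printf("\n");
}
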
@@ -4282,7 +4314,7 @@ struct llm_build_llama : public llm_graph_context {
            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
            cb(ffn_inp, "ffn_inp", il);

-            // feed-forward network
+            // feed-forward network (non-MoE)
            if (model.layers[il].ffn_gate_inp == nullptr) {

                cur = build_norm(ffn_inp,

@@ -4297,6 +4329,35 @@ struct llm_build_llama : public llm_graph_context {
                        NULL,
                        LLM_FFN_SILU, LLM_FFN_PAR, il);
                cb(cur, "ffn_out", il);
+
+            } else if (arch == LLM_ARCH_LLAMA4) {
+                // llama4 MoE
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, false,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
+                        il);
+
+                // Shared experts
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_moe_shexp", il);
+
            } else {
                // MoE branch
                cur = build_norm(ffn_inp,

@@ -12091,6 +12152,7 @@ llm_graph_result_ptr llama_model::build_graph(

    switch (arch) {
        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
        case LLM_ARCH_MINICPM:
        case LLM_ARCH_GRANITE:
        case LLM_ARCH_GRANITE_MOE:

@@ -12440,6 +12502,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {

        // use what we call a normal RoPE, operating on pairs of consecutive head values
        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
        case LLM_ARCH_DECI:
        case LLM_ARCH_BAICHUAN:
        case LLM_ARCH_STARCODER:
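
The shared-expert branch, like the dense feed-forward layers, goes through build_ffn with LLM_FFN_SILU and LLM_FFN_PAR, i.e. a SwiGLU-style gated FFN of the form down(silu(gate(x)) * up(x)). A minimal standalone sketch of that computation on toy 2x2 matrices (plain vectors rather than ggml tensors; all weights and sizes are illustrative):

// Standalone sketch of the SILU + parallel-gate feed-forward.
#include <cmath>
#include <cstdio>
#include <vector>

using vec = std::vector<float>;
using mat = std::vector<vec>; // row-major, rows = output dim

static vec matvec(const mat & w, const vec & x) {
    vec y(w.size(), 0.0f);
    for (size_t r = 0; r < w.size(); ++r)
        for (size_t c = 0; c < x.size(); ++c)
            y[r] += w[r][c] * x[c];
    return y;
}

// down( silu(gate(x)) * up(x) )
static vec swiglu_ffn(const mat & w_up, const mat & w_gate, const mat & w_down, const vec & x) {
    vec up   = matvec(w_up,   x);
    vec gate = matvec(w_gate, x);
    for (size_t i = 0; i < gate.size(); ++i) {
        const float silu = gate[i] / (1.0f + std::exp(-gate[i]));
        up[i] *= silu; // elementwise product of the two parallel branches
    }
    return matvec(w_down, up);
}

int main() {
    const mat w_up   = {{ 0.5f, -0.2f}, { 0.1f, 0.3f}};
    const mat w_gate = {{ 0.7f,  0.4f}, {-0.6f, 0.2f}};
    const mat w_down = {{ 1.0f,  0.0f}, { 0.0f, 1.0f}};
    const vec x      = { 1.0f,  2.0f };
    for (float v : swiglu_ffn(w_up, w_gate, w_down, x)) std::printf("%.4f ", v);
    std::printf("\n");
}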

src/llama-vocab.cpp

Lines changed: 2 additions & 1 deletion
@@ -1616,7 +1616,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                tokenizer_pre == "megrez") {
            pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
        } else if (
-                tokenizer_pre == "gpt-4o") {
+                tokenizer_pre == "gpt-4o" ||
+                tokenizer_pre == "llama4") {
            pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
            clean_spaces = false;
        } else if (
