Skip to content

Commit 506d215

Browse files
committed
Add Arcee AFM support
1 parent fb85a28 commit 506d215

File tree

6 files changed

+238
-0
lines changed

6 files changed

+238
-0
lines changed

convert_hf_to_gguf.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2020,6 +2020,20 @@ def prepare_tensors(self):
20202020
raise ValueError(f"Unprocessed experts: {experts}")
20212021

20222022

2023+
@ModelBase.register("ArceeForCausalLM")
2024+
class ArceeModel(LlamaModel):
2025+
model_arch = gguf.MODEL_ARCH.ARCEE
2026+
2027+
def set_gguf_parameters(self):
2028+
super().set_gguf_parameters()
2029+
self._try_set_pooling_type()
2030+
rope_scaling = self.hparams.get("rope_scaling") or {}
2031+
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
2032+
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
2033+
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
2034+
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
2035+
2036+
20232037
@ModelBase.register(
20242038
"LlavaForConditionalGeneration", # pixtral
20252039
"Mistral3ForConditionalGeneration", # mistral small 3.1

gguf-py/gguf/constants.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ class MODEL_ARCH(IntEnum):
343343
WAVTOKENIZER_DEC = auto()
344344
PLM = auto()
345345
BAILINGMOE = auto()
346+
ARCEE = auto()
346347

347348

348349
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -623,6 +624,7 @@ class MODEL_TENSOR(IntEnum):
623624
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
624625
MODEL_ARCH.PLM: "plm",
625626
MODEL_ARCH.BAILINGMOE: "bailingmoe",
627+
MODEL_ARCH.ARCEE: "arcee",
626628
}
627629

628630
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2044,6 +2046,26 @@ class MODEL_TENSOR(IntEnum):
20442046
MODEL_TENSOR.FFN_DOWN_SHEXP,
20452047
MODEL_TENSOR.FFN_UP_SHEXP,
20462048
],
2049+
MODEL_ARCH.ARCEE: [
2050+
MODEL_TENSOR.TOKEN_EMBD,
2051+
MODEL_TENSOR.OUTPUT_NORM,
2052+
MODEL_TENSOR.OUTPUT,
2053+
MODEL_TENSOR.ROPE_FREQS,
2054+
MODEL_TENSOR.ATTN_NORM,
2055+
MODEL_TENSOR.ATTN_Q,
2056+
MODEL_TENSOR.ATTN_K,
2057+
MODEL_TENSOR.ATTN_V,
2058+
MODEL_TENSOR.ATTN_OUT,
2059+
MODEL_TENSOR.ATTN_ROT_EMBD,
2060+
MODEL_TENSOR.FFN_GATE_INP,
2061+
MODEL_TENSOR.FFN_NORM,
2062+
MODEL_TENSOR.FFN_GATE,
2063+
MODEL_TENSOR.FFN_DOWN,
2064+
MODEL_TENSOR.FFN_UP,
2065+
MODEL_TENSOR.FFN_GATE_EXP,
2066+
MODEL_TENSOR.FFN_DOWN_EXP,
2067+
MODEL_TENSOR.FFN_UP_EXP,
2068+
],
20472069
# TODO
20482070
}
20492071

src/llama-arch.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
7272
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
7373
{ LLM_ARCH_PLM, "plm" },
7474
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
75+
{ LLM_ARCH_ARCEE, "arcee" },
7576
{ LLM_ARCH_UNKNOWN, "(unknown)" },
7677
};
7778

@@ -243,6 +244,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
243244
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
244245
},
245246
},
247+
{
248+
LLM_ARCH_ARCEE,
249+
{
250+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
251+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
252+
{ LLM_TENSOR_OUTPUT, "output" },
253+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
254+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
255+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
256+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
257+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
258+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
259+
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
260+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
261+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
262+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
263+
},
264+
},
246265
{
247266
LLM_ARCH_LLAMA4,
248267
{

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ enum llm_arch {
7676
LLM_ARCH_WAVTOKENIZER_DEC,
7777
LLM_ARCH_PLM,
7878
LLM_ARCH_BAILINGMOE,
79+
LLM_ARCH_ARCEE,
7980
LLM_ARCH_UNKNOWN,
8081
};
8182

src/llama-model.cpp

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
598598
hparams.use_kq_norm = false;
599599
}
600600
} break;
601+
case LLM_ARCH_ARCEE:
602+
{
603+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
604+
605+
// Arcee uses the same structure as Llama
606+
switch (hparams.n_layer) {
607+
case 36: type = LLM_TYPE_4B; break;
608+
default: type = LLM_TYPE_UNKNOWN;
609+
}
610+
} break;
601611
case LLM_ARCH_DECI:
602612
{
603613
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -4123,6 +4133,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
41234133
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
41244134
}
41254135
} break;
4136+
case LLM_ARCH_ARCEE:
4137+
{
4138+
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4139+
4140+
// output
4141+
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4142+
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
4143+
4144+
// if output is NULL, init from the input tok embed
4145+
if (output == NULL) {
4146+
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
4147+
}
4148+
4149+
for (int i = 0; i < n_layer; ++i) {
4150+
auto & layer = layers[i];
4151+
4152+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4153+
4154+
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4155+
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
4156+
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
4157+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
4158+
4159+
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4160+
4161+
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
4162+
4163+
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4164+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4165+
}
4166+
} break;
41264167
default:
41274168
throw std::runtime_error("unknown architecture");
41284169
}
@@ -13194,6 +13235,141 @@ struct llm_build_bailingmoe : public llm_graph_context {
1319413235
}
1319513236
};
1319613237

13238+
struct llm_build_arcee : public llm_graph_context {
13239+
llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
13240+
const int64_t n_embd_head = hparams.n_embd_head_v;
13241+
13242+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
13243+
GGML_ASSERT(n_embd_head == hparams.n_rot);
13244+
13245+
ggml_tensor * cur;
13246+
ggml_tensor * inpL;
13247+
13248+
inpL = build_inp_embd(model.tok_embd);
13249+
13250+
// inp_pos - contains the positions
13251+
ggml_tensor * inp_pos = build_inp_pos();
13252+
13253+
auto * inp_attn = build_attn_inp_kv_unified();
13254+
13255+
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
13256+
13257+
for (int il = 0; il < n_layer; ++il) {
13258+
ggml_tensor * inpSA = inpL;
13259+
13260+
// norm
13261+
cur = build_norm(inpL,
13262+
model.layers[il].attn_norm, NULL,
13263+
LLM_NORM_RMS, il);
13264+
cb(cur, "attn_norm", il);
13265+
13266+
// self-attention
13267+
{
13268+
// rope freq factors for llama3; may return nullptr for llama2 and other models
13269+
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
13270+
13271+
// compute Q and K and RoPE them
13272+
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
13273+
cb(Qcur, "Qcur", il);
13274+
if (model.layers[il].bq) {
13275+
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
13276+
cb(Qcur, "Qcur", il);
13277+
}
13278+
13279+
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
13280+
cb(Kcur, "Kcur", il);
13281+
if (model.layers[il].bk) {
13282+
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
13283+
cb(Kcur, "Kcur", il);
13284+
}
13285+
13286+
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
13287+
cb(Vcur, "Vcur", il);
13288+
if (model.layers[il].bv) {
13289+
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
13290+
cb(Vcur, "Vcur", il);
13291+
}
13292+
13293+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
13294+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
13295+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
13296+
13297+
Qcur = ggml_rope_ext(
13298+
ctx0, Qcur, inp_pos, rope_factors,
13299+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
13300+
ext_factor, attn_factor, beta_fast, beta_slow
13301+
);
13302+
13303+
Kcur = ggml_rope_ext(
13304+
ctx0, Kcur, inp_pos, rope_factors,
13305+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
13306+
ext_factor, attn_factor, beta_fast, beta_slow
13307+
);
13308+
13309+
cb(Qcur, "Qcur", il);
13310+
cb(Kcur, "Kcur", il);
13311+
cb(Vcur, "Vcur", il);
13312+
13313+
cur = build_attn(inp_attn, gf,
13314+
model.layers[il].wo, model.layers[il].bo,
13315+
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
13316+
cb(cur, "attn_out", il);
13317+
}
13318+
13319+
if (il == n_layer - 1) {
13320+
// skip computing output for unused tokens
13321+
ggml_tensor * inp_out_ids = build_inp_out_ids();
13322+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13323+
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13324+
}
13325+
13326+
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
13327+
cb(ffn_inp, "ffn_inp", il);
13328+
13329+
// feed-forward network
13330+
// ARCEE uses relu^2 instead of swiglu
13331+
cur = build_norm(ffn_inp,
13332+
model.layers[il].ffn_norm, NULL,
13333+
LLM_NORM_RMS, il);
13334+
cb(cur, "ffn_norm", il);
13335+
13336+
cur = build_ffn(cur,
13337+
model.layers[il].ffn_up, NULL, NULL,
13338+
NULL, NULL, NULL,
13339+
model.layers[il].ffn_down, NULL, NULL,
13340+
NULL,
13341+
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
13342+
cb(cur, "ffn_out", il);
13343+
13344+
cur = ggml_add(ctx0, cur, ffn_inp);
13345+
cb(cur, "ffn_out", il);
13346+
13347+
cur = build_cvec(cur, il);
13348+
cb(cur, "l_out", il);
13349+
13350+
// input for next layer
13351+
inpL = cur;
13352+
}
13353+
13354+
cur = inpL;
13355+
13356+
cur = build_norm(cur,
13357+
model.output_norm, NULL,
13358+
LLM_NORM_RMS, -1);
13359+
13360+
cb(cur, "result_norm", -1);
13361+
res->t_embd = cur;
13362+
13363+
// lm_head
13364+
cur = build_lora_mm(model.output, cur);
13365+
13366+
cb(cur, "result_output", -1);
13367+
res->t_logits = cur;
13368+
13369+
ggml_build_forward_expand(gf, cur);
13370+
}
13371+
};
13372+
1319713373
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
1319813374
llama_memory_i * res;
1319913375

@@ -13532,6 +13708,10 @@ llm_graph_result_ptr llama_model::build_graph(
1353213708
{
1353313709
llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
1353413710
} break;
13711+
case LLM_ARCH_ARCEE:
13712+
{
13713+
llm = std::make_unique<llm_build_arcee>(*this, params, gf);
13714+
} break;
1353513715
default:
1353613716
GGML_ABORT("fatal error");
1353713717
}
@@ -13681,6 +13861,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
1368113861
case LLM_ARCH_GRANITE_MOE:
1368213862
case LLM_ARCH_CHAMELEON:
1368313863
case LLM_ARCH_BAILINGMOE:
13864+
case LLM_ARCH_ARCEE:
1368413865
return LLAMA_ROPE_TYPE_NORM;
1368513866

1368613867
// the pairs of head values are offset by n_rot/2

src/llama-vocab.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,6 +1987,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
19871987
|| t.first == "<|eom_id|>"
19881988
|| t.first == "<EOT>"
19891989
|| t.first == "_<EOT>"
1990+
|| t.first == "<|end_of_text|>"
19901991
) {
19911992
special_eog_ids.insert(t.second);
19921993
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {

0 commit comments

Comments
 (0)