
Commit 512bd19

bartowski1182 authored and ngxson committed
model : Add support for Arcee AI's upcoming AFM model (ggml-org#14185)
* Add Arcee AFM support
* Add draft update code
* Fix linter and update URL, may still not be final
* Update src/llama-model.cpp (Co-authored-by: Xuan-Son Nguyen <[email protected]>)
* Remove accidental blank line

Co-authored-by: Xuan-Son Nguyen <[email protected]>
1 parent a820a9e commit 512bd19

File tree: 5 files changed (+191, -3 lines)

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions

@@ -128,6 +128,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "llama4",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
     {"name": "pixtral",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
     {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
+    {"name": "arcee",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/AFM-4.5B", }, # TODO confirm final URL
 ]

 # some models are known to be broken upstream, so we will skip them as exceptions
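
The new registry entry feeds convert_hf_to_gguf_update.py, which downloads each listed tokenizer and records a hash that convert_hf_to_gguf.py later uses to recognize the pre-tokenizer. A minimal sketch of that fingerprinting idea, assuming the Hugging Face transformers package and access to the (possibly gated) arcee-ai/AFM-4.5B repo; the check text below is a short stand-in for the longer fixed string the real script uses:

from hashlib import sha256
from transformers import AutoTokenizer

def tokenizer_fingerprint(repo_or_path: str, check_text: str = "Hello world 123\n\t") -> str:
    # encode a fixed probe string and hash the resulting token-id sequence;
    # tokenizers with identical pre-tokenization produce the same digest
    tok = AutoTokenizer.from_pretrained(repo_or_path)
    return sha256(str(tok.encode(check_text)).encode()).hexdigest()

print(tokenizer_fingerprint("arcee-ai/AFM-4.5B"))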

gguf-py/gguf/constants.py

Lines changed: 18 additions & 1 deletion

@@ -353,6 +353,7 @@ class MODEL_ARCH(IntEnum):
     PLM        = auto()
     BAILINGMOE = auto()
     DOTS1      = auto()
+    ARCEE      = auto()


 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -651,7 +652,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
     MODEL_ARCH.PLM:              "plm",
     MODEL_ARCH.BAILINGMOE:       "bailingmoe",
-    MODEL_ARCH.DOTS1:            "dots1"
+    MODEL_ARCH.DOTS1:            "dots1",
+    MODEL_ARCH.ARCEE:            "arcee",
 }

 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2160,6 +2162,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
     ],
+    MODEL_ARCH.ARCEE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
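
With these registrations in place, the arcee entry is visible to any gguf-py consumer; a quick sanity check (a sketch, assuming the gguf package from this tree is importable):

from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSORS

print(MODEL_ARCH_NAMES[MODEL_ARCH.ARCEE])      # -> "arcee"
for t in MODEL_TENSORS[MODEL_ARCH.ARCEE]:
    print(t.name)                              # TOKEN_EMBD, OUTPUT_NORM, OUTPUT, ...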

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_PLM,        "plm"        },
     { LLM_ARCH_BAILINGMOE, "bailingmoe" },
     { LLM_ARCH_DOTS1,      "dots1"      },
+    { LLM_ARCH_ARCEE,      "arcee"      },
     { LLM_ARCH_UNKNOWN,    "(unknown)"  },
 };

src/llama-arch.h

Lines changed: 1 addition & 0 deletions

@@ -79,6 +79,7 @@ enum llm_arch {
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_DOTS1,
+    LLM_ARCH_ARCEE,
     LLM_ARCH_UNKNOWN,
 };

src/llama-model.cpp

Lines changed: 170 additions & 2 deletions

@@ -4318,6 +4318,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 }
             } break;
+        case LLM_ARCH_ARCEE:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                }
+            } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
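
In the block above, tn(...) resolves the symbolic tensor ids into the canonical GGUF tensor names stored in the file. A rough Python equivalent built on gguf-py's lookup table (a sketch of the naming convention, not the C++ tn() helper; the gguf_name() function here is hypothetical):

from gguf.constants import MODEL_TENSOR, TENSOR_NAMES

def gguf_name(tensor: MODEL_TENSOR, suffix: str = "weight", bid: int | None = None) -> str:
    base = TENSOR_NAMES[tensor]
    if bid is not None:
        base = base.format(bid=bid)  # per-layer tensors carry a "{bid}" slot
    return f"{base}.{suffix}"

print(gguf_name(MODEL_TENSOR.TOKEN_EMBD))     # token_embd.weight
print(gguf_name(MODEL_TENSOR.ATTN_Q, bid=0))  # blk.0.attn_q.weight
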
@@ -14108,6 +14139,141 @@ struct llm_build_dots1 : public llm_graph_context {
     }
 };

+struct llm_build_arcee : public llm_graph_context {
+    llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // ARCEE uses relu^2 instead of silu
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    NULL,                      NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
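
The comment in the FFN block above notes that this architecture uses relu^2 instead of silu (selected via LLM_FFN_RELU_SQR). As a tiny numeric illustration of that activation, here is a NumPy sketch (illustration only; the graph above stays in ggml ops):

import numpy as np

def relu_squared(x: np.ndarray) -> np.ndarray:
    # zero out negatives, then square element-wise: relu(x)^2
    return np.square(np.maximum(x, 0.0))

print(relu_squared(np.array([-2.0, -0.5, 0.5, 2.0])))  # [0.   0.   0.25 4.  ]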

@@ -14479,6 +14645,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_dots1>(*this, params, gf);
             } break;
+        case LLM_ARCH_ARCEE:
+            {
+                llm = std::make_unique<llm_build_arcee>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -14628,9 +14798,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
-        case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_ARCEE:
-        case LLM_ARCH_ERNIE4_5:
             return LLAMA_ROPE_TYPE_NORM;

         // the pairs of head values are offset by n_rot/2
