@@ -18035,6 +18035,194 @@ struct llm_build_arcee : public llm_graph_context {
1803518035 }
1803618036};
1803718037
18038+ struct llm_build_afmoe : public llm_graph_context {
18039+ llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
18040+ const int64_t n_embd_head = hparams.n_embd_head_v;
18041+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
18042+
18043+ ggml_tensor * cur;
18044+ ggml_tensor * inpL;
18045+
18046+ inpL = build_inp_embd(model.tok_embd);
18047+
18048+ // MuP scaling: embeddings * sqrt(hidden_size)
18049+ // mup_enabled = true, hidden_size = 1024, scale = 32.0
18050+ inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
18051+ cb(inpL, "inp_embd_scaled", -1);
18052+
18053+ // inp_pos - contains the positions
18054+ ggml_tensor * inp_pos = build_inp_pos();
18055+ auto * inp_attn = build_attn_inp_kv();
18056+ ggml_tensor * inp_out_ids = build_inp_out_ids();
18057+
18058+ const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
18059+
18060+ for (int il = 0; il < n_layer; ++il) {
18061+ ggml_tensor * inpSA = inpL;
18062+
18063+ // dual attention normalization (pre)
18064+ cur = build_norm(inpL,
18065+ model.layers[il].attn_norm, NULL,
18066+ LLM_NORM_RMS, il);
18067+ cb(cur, "attn_norm", il);
18068+
18069+ // self-attention
18070+ {
18071+ ggml_tensor * attn_inp = cur; // save input for gate computation
18072+
18073+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
18074+ cb(Qcur, "Qcur", il);
18075+
18076+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
18077+ cb(Kcur, "Kcur", il);
18078+
18079+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
18080+ cb(Vcur, "Vcur", il);
18081+
18082+ // compute gate from input
18083+ ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
18084+ cb(gate, "attn_gate_proj", il);
18085+
18086+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
18087+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
18088+
18089+ // Q/K normalization
18090+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
18091+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
18092+ cb(Qcur, "Qcur_normed", il);
18093+ cb(Kcur, "Kcur_normed", il);
18094+
18095+ // RoPE only for sliding_attention layers (every 4th layer is full_attention)
18096+ // layer_types[i] = "sliding_attention" if (i+1) % global_attn_every_n_layers != 0
18097+ bool is_sliding = ((il + 1) % 4) != 0; // global_attn_every_n_layers = 4
18098+ if (is_sliding) {
18099+ Qcur = ggml_rope_ext(
18100+ ctx0, Qcur, inp_pos, nullptr,
18101+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
18102+ ext_factor, attn_factor, beta_fast, beta_slow);
18103+ cb(Qcur, "Qcur_rope", il);
18104+
18105+ Kcur = ggml_rope_ext(
18106+ ctx0, Kcur, inp_pos, nullptr,
18107+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
18108+ ext_factor, attn_factor, beta_fast, beta_slow);
18109+ cb(Kcur, "Kcur_rope", il);
18110+ }
18111+
18112+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
18113+
18114+ cur = build_attn(inp_attn,
18115+ NULL, NULL, // wo will be applied after gating
18116+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
18117+ cb(cur, "attn_out", il);
18118+
18119+ // attention gating: attn_out * sigmoid(gate) BEFORE o_proj
18120+ gate = ggml_sigmoid(ctx0, gate);
18121+ cb(gate, "attn_gate_sig", il);
18122+ cur = ggml_mul(ctx0, cur, gate);
18123+ cb(cur, "attn_gated", il);
18124+
18125+ // now apply output projection
18126+ cur = build_lora_mm(model.layers[il].wo, cur);
18127+ cb(cur, "attn_o_proj", il);
18128+ }
18129+
18130+ // dual attention normalization (post)
18131+ cur = build_norm(cur,
18132+ model.layers[il].attn_norm_2, NULL,
18133+ LLM_NORM_RMS, il);
18134+ cb(cur, "attn_norm_2", il);
18135+
18136+ if (il == n_layer - 1 && inp_out_ids) {
18137+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
18138+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
18139+ }
18140+
18141+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
18142+ cb(ffn_inp, "ffn_inp", il);
18143+
18144+ // dual ffn normalization (pre)
18145+ cur = build_norm(ffn_inp,
18146+ model.layers[il].ffn_norm, NULL,
18147+ LLM_NORM_RMS, il);
18148+ cb(cur, "ffn_norm", il);
18149+
18150+ // MoE or dense FFN
18151+ if ((uint32_t)il >= hparams.n_layer_dense_lead) {
18152+ // MoE layer with sigmoid routing, normalization, and scaling
18153+ ggml_tensor * moe_out = build_moe_ffn(cur,
18154+ model.layers[il].ffn_gate_inp,
18155+ model.layers[il].ffn_up_exps,
18156+ model.layers[il].ffn_gate_exps,
18157+ model.layers[il].ffn_down_exps,
18158+ nullptr,
18159+ n_expert, n_expert_used,
18160+ LLM_FFN_SILU,
18161+ hparams.expert_weights_norm != 0, // norm_w (route_norm=True)
18162+ hparams.expert_weights_scale != 0.0f, // scale_w
18163+ hparams.expert_weights_scale, // w_scale (route_scale=2.826)
18164+ (llama_expert_gating_func_type) hparams.expert_gating_func,
18165+ il);
18166+ cb(moe_out, "ffn_moe_out", il);
18167+
18168+ // shared expert
18169+ if (hparams.n_expert_shared > 0) {
18170+ ggml_tensor * ffn_shexp = build_ffn(cur,
18171+ model.layers[il].ffn_up_shexp, NULL, NULL,
18172+ model.layers[il].ffn_gate_shexp, NULL, NULL,
18173+ model.layers[il].ffn_down_shexp, NULL, NULL,
18174+ NULL,
18175+ LLM_FFN_SILU, LLM_FFN_PAR, il);
18176+ cb(ffn_shexp, "ffn_shexp", il);
18177+
18178+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
18179+ cb(cur, "ffn_out", il);
18180+ } else {
18181+ cur = moe_out;
18182+ }
18183+ } else {
18184+ // dense layer
18185+ cur = build_ffn(cur,
18186+ model.layers[il].ffn_up, NULL, NULL,
18187+ model.layers[il].ffn_gate, NULL, NULL,
18188+ model.layers[il].ffn_down, NULL, NULL,
18189+ NULL,
18190+ LLM_FFN_SILU, LLM_FFN_PAR, il);
18191+ cb(cur, "ffn_out", il);
18192+ }
18193+
18194+ // dual ffn normalization (post)
18195+ cur = build_norm(cur,
18196+ model.layers[il].ffn_post_norm, NULL,
18197+ LLM_NORM_RMS, il);
18198+ cb(cur, "ffn_post_norm", il);
18199+
18200+ cur = ggml_add(ctx0, cur, ffn_inp);
18201+ cur = build_cvec(cur, il);
18202+ cb(cur, "l_out", il);
18203+
18204+ // input for next layer
18205+ inpL = cur;
18206+ }
18207+
18208+ cur = inpL;
18209+
18210+ cur = build_norm(cur,
18211+ model.output_norm, NULL,
18212+ LLM_NORM_RMS, -1);
18213+ cb(cur, "result_norm", -1);
18214+
18215+ res->t_embd = cur;
18216+
18217+ // lm_head
18218+ cur = build_lora_mm(model.output, cur);
18219+ cb(cur, "result_output", -1);
18220+ res->t_logits = cur;
18221+
18222+ ggml_build_forward_expand(gf, cur);
18223+ }
18224+ };
18225+
1803818226struct llm_build_hunyuan_moe : public llm_graph_context {
1803918227 llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
1804018228 const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -19667,7 +19855,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1966719855 } break;
1966819856 case LLM_ARCH_AFMOE:
1966919857 {
19670- llm = std::make_unique<llm_build_arcee >(*this, params);
19858+ llm = std::make_unique<llm_build_afmoe >(*this, params);
1967119859 } break;
1967219860 case LLM_ARCH_ERNIE4_5:
1967319861 {
0 commit comments