@@ -4507,6 +4507,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
                 for (int i = 0; i < n_layer; ++i) {
                     int flags = 0;
+
                     if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
                         // skip all tensors in the NextN layers
                         flags |= TENSOR_SKIP;
@@ -13919,6 +13920,144 @@ struct llm_build_glm4_moe : public llm_graph_context {
     }
 };
 
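+// Graph builder for the GLM-4.5 MoE NextN/MTP (multi-token prediction) layer:
+// re-embeds the last sampled token, fuses it with the main model's final hidden
+// state via eh_proj, runs the single NextN block, and emits draft logits.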
+// For v0, rebuild the computational graph on every step; this mimics the
+// parameterization of the vLLM implementation.
+struct llm_build_glm4_moe_mtp : public llm_graph_context {
+    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
+            ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
+    ) : llm_graph_context(params) {
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        // assuming a single MTP layer at the end of the layer stack
+        const int il = hparams.n_layer - 1;
+        const auto & mtp_layer = model.layers[il];
+
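+        // single-position inputs: the draft token sits at absolute position n_past
+        // (writing values at graph-build time assumes ctx0 tensor data is allocated)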
+        ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
+        ggml_set_i32(inp_pos, n_past);
+        llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();
+
+        ggml_tensor * cur;
+
+        // get the MTP embedding of the last (conventionally sampled) token
+        ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
+        ggml_set_i32(inp_token_id, last_token_id);
+        ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
+        ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
+
+        // vLLM l99: previous_hidden_states = self.hnorm(previous_hidden_states)
+        ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
+
+        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat along the embedding dim
+        cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
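+        // eh_proj maps the concatenated 2*n_embd vector back down to n_embd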
+
+        // now run the NextN layer itself (skipped in the main model's forward pass)
+        ggml_tensor * inpSA = cur; // residual input for the attention block
+
+        // pre-attention norm for the MTP block
+        ggml_tensor * attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
+
+        // self-attention
+        {
+            ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, attn_inp);
+            if (mtp_layer.bq) {
+                Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
+            }
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, attn_inp);
+            if (mtp_layer.bk) {
+                Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
+            }
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, attn_inp);
+            if (mtp_layer.bv) {
+                Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
+            }
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // apply Q/K norm if available (GLM-4.5 355B variant)
+            if (mtp_layer.attn_q_norm) {
+                Qcur = build_norm(Qcur, mtp_layer.attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+            }
+            if (mtp_layer.attn_k_norm) {
+                Kcur = build_norm(Kcur, mtp_layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+            }
+
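+            // rotate Q/K at the draft token's absolute position (inp_pos == n_past)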
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    mtp_layer.wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+
+        cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il);
+
+        // MoE FFN for the NextN block
+        {
+            // process routed experts using the existing MoE infrastructure
+            ggml_tensor * routed_out = build_moe_ffn(cur,
+                    mtp_layer.ffn_gate_inp,
+                    mtp_layer.ffn_up_exps,
+                    mtp_layer.ffn_gate_exps,
+                    mtp_layer.ffn_down_exps,
+                    mtp_layer.ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, hparams.expert_weights_norm,
+                    true, hparams.expert_weights_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(routed_out, "ffn_moe_out", il);
+
+            // process the shared expert on the same normed input
+            ggml_tensor * shared_out = build_ffn(cur,
+                    mtp_layer.ffn_up_shexp, NULL, NULL,
+                    mtp_layer.ffn_gate_shexp, NULL, NULL,
+                    mtp_layer.ffn_down_shexp, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(shared_out, "ffn_shexp_out", il);
+
+            // final FFN output: routed_out + shared_out
+            cur = ggml_add(ctx0, routed_out, shared_out);
+            cb(cur, "ffn_out", il);
+        }
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
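+        // NextN output head: final RMS norm, then project to vocabulary logits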
+        cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
+        cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
+
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, res->t_logits);
+    }
+};
+
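+// Minimal usage sketch (hypothetical call site, names are illustrative):
+// after the main model samples a token, rebuild this graph with the sampled
+// id and the main pass's final hidden state to obtain draft logits:
+//
+//     llm_build_glm4_moe_mtp mtp(model, params, t_last_hidden, sampled_id, n_past);
+//     ggml_tensor * draft_logits = mtp.res->t_logits;
+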
 struct llm_build_nemotron : public llm_graph_context {
     llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;