
Commit e434f87

some work towards building mtp layer graph
1 parent db60623 commit e434f87


2 files changed: +149 additions, -8 deletions


src/llama-model.cpp

Lines changed: 139 additions & 0 deletions

@@ -4507,6 +4507,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
     for (int i = 0; i < n_layer; ++i) {
         int flags = 0;
+
         if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
             // skip all tensors in the NextN layers
             flags |= TENSOR_SKIP;

@@ -13919,6 +13920,144 @@ struct llm_build_glm4_moe : public llm_graph_context {
     }
 };
 
+struct llm_build_glm4_moe_mtp : public llm_graph_context {
+    // For v0, rebuild the computational graph on every step; this mimics the vLLM implementation's parameterization
+    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
+            ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
+    ) : llm_graph_context(params) {
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        // assume a single MTP layer at the end
+        const int il = hparams.n_layer - 1;
+        const auto & mtp_layer = model.layers[il];
+
+        ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
+        ggml_set_i32(inp_pos, n_past);
+        llm_graph_input_attn_no_cache * inp_attn = nullptr;
+
+        ggml_tensor * cur;
+
+        // get the MTP embedding for the last (conventionally sampled) token
+        ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
+        ggml_set_i32(inp_token_id, last_token_id);
+        ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
+        ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
+
+        // vLLM (l99): previous_hidden_states = self.hnorm(previous_hidden_states)
+        ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
+
+        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat
+        cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
+
+
+        // now proceed through the last layer (skipped in the main model's forward pass)
+        ggml_tensor * inpSA = cur;
+
+        // pre-attention norm for the MTP block feeds the Q/K/V projections
+        ggml_tensor * attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
+
+        // self-attention
+        {
+            ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, attn_inp);
+            if (mtp_layer.bq) {
+                Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
+            }
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, attn_inp);
+            if (mtp_layer.bk) {
+                Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
+            }
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, attn_inp);
+            if (mtp_layer.bv) {
+                Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
+            }
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // apply Q/K norm if available (GLM-4.5 355B variant)
+            if (mtp_layer.attn_q_norm) {
+                Qcur = build_norm(Qcur, mtp_layer.attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+            }
+            if (mtp_layer.attn_k_norm) {
+                Kcur = build_norm(Kcur, mtp_layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+            }
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    mtp_layer.wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+
+        cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il);
+
+        // MoE FFN for the NextN block
+        {
+            // process routed experts using the existing MoE infrastructure
+            ggml_tensor * routed_out = build_moe_ffn(cur,
+                    mtp_layer.ffn_gate_inp,
+                    mtp_layer.ffn_up_exps,
+                    mtp_layer.ffn_gate_exps,
+                    mtp_layer.ffn_down_exps,
+                    mtp_layer.ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, hparams.expert_weights_norm,
+                    true, hparams.expert_weights_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(routed_out, "ffn_moe_out", il);
+
+            // process the shared expert on the same normed input
+            ggml_tensor * shared_out = build_ffn(cur,
+                    mtp_layer.ffn_up_shexp, NULL, NULL,
+                    mtp_layer.ffn_gate_shexp, NULL, NULL,
+                    mtp_layer.ffn_down_shexp, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(shared_out, "ffn_shexp_out", il);
+
+            // final output: routed_out + shared_out
+            cur = ggml_add(ctx0, routed_out, shared_out);
+            cb(cur, "ffn_out", il);
+        }
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
+        cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
+
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, res->t_logits);
+    }
+};
+
 struct llm_build_nemotron : public llm_graph_context {
     llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
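
Stripped of ggml plumbing, the data flow of the new builder is: look up the NextN embedding of the last sampled token, RMS-normalize it (enorm), RMS-normalize the previous hidden state (hnorm), concatenate the two along the embedding dimension, and project back to n_embd with eh_proj; the result feeds the single NextN transformer block, whose output passes through shared_head_norm and shared_head_head to produce logits. Below is a minimal standalone sketch of just the combination step in plain C++ with toy dimensions and weights; rms_norm here omits the learned per-channel scale, and none of these names are ggml or llama.cpp APIs.

// Illustration only: RMS-norm both inputs, concatenate, project 2*n_embd -> n_embd.
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> rms_norm(const std::vector<float> & x, float eps = 1e-6f) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    const float scale = 1.0f / std::sqrt(ss / x.size() + eps);
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * scale;
    return out;
}

int main() {
    const int n_embd = 4; // toy size

    // stand-ins for the embed_tokens row of the sampled token and the main model's last hidden state
    const std::vector<float> token_emb   = {0.1f, -0.2f,  0.3f, 0.4f};
    const std::vector<float> prev_hidden = {0.5f,  0.1f, -0.3f, 0.2f};

    // enorm / hnorm, then concat along the feature axis
    std::vector<float> combined = rms_norm(token_emb);
    const std::vector<float> h  = rms_norm(prev_hidden);
    combined.insert(combined.end(), h.begin(), h.end());

    // eh_proj: an (n_embd x 2*n_embd) matrix; here it simply averages the two halves
    std::vector<float> cur(n_embd, 0.0f);
    for (int i = 0; i < n_embd; ++i) {
        cur[i] = 0.5f * combined[i] + 0.5f * combined[n_embd + i];
    }

    for (float v : cur) std::printf("% .4f ", v); // input to the NextN block
    std::printf("\n");
    return 0;
}

The same shape bookkeeping is what ggml_concat along dim 0 followed by build_lora_mm on eh_proj performs inside the graph.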

tools/server/server.cpp

Lines changed: 10 additions & 8 deletions

@@ -1432,7 +1432,7 @@ struct server_slot {
     }
 
     bool can_speculate() const {
-        return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+        return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
     }
 
     void add_token(const completion_token_output & token) {

@@ -2122,14 +2122,16 @@ struct server_context {
                     common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
                 }
             }
+
+            // if model has MTP and no draft model is specified...
             else if (llama_model_n_nextn_layer(model) > 0) {
-                    SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
-                    slot.has_mtp = true;
-
-                    // assume one speculative token (true of all well-known MTP models so far)
-                    slot.batch_spec = llama_batch_init(2, 0, 1);
-                    params_base.speculative.n_min = 0;
-                    params_base.speculative.n_max = 1;
+                SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
+                slot.has_mtp = true;
+
+                // assume one speculative token (true of all well-known MTP models so far)
+                slot.batch_spec = llama_batch_init(2, 0, 1);
+                params_base.speculative.n_min = 0;
+                params_base.speculative.n_max = 1;
             }
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
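
Taken together, the server-side settings amount to one-token speculation: when the target model has NextN layers and no separate draft model is loaded, the slot is flagged has_mtp, the speculative batch only ever holds the accepted token plus a single draft, and n_max is pinned to 1 (n_min = 0 presumably keeps a one-token draft from being discarded as too short). A rough sketch of that control flow under those assumptions follows; draft_with_mtp and sample_target are hypothetical stand-ins, not llama.cpp APIs.

// Illustration of single-token MTP speculation: draft one token per step,
// then let the target model's next decode either confirm or replace it.
#include <cstdio>

using token = int;

// Hypothetical: would run the NextN/MTP graph on the last hidden state.
static token draft_with_mtp(token last) { return last + 1; }

// Hypothetical: would decode {last, draft} in one batch (batch_spec of size 2)
// and return the target model's sample at the position after `last`.
static token sample_target(token last, token /*draft*/) { return last + 1; }

int main() {
    token last = 42;

    for (int step = 0; step < 4; ++step) {
        const token draft  = draft_with_mtp(last);      // n_max = 1: exactly one speculative token
        const token target = sample_target(last, draft); // verification decode

        if (target == draft) {
            std::printf("step %d: draft %d accepted (two tokens from one decode)\n", step, draft);
        } else {
            std::printf("step %d: draft %d rejected, keeping %d\n", step, draft, target);
        }
        last = target;
    }
    return 0;
}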
