Conversation
? |
|
Tried to declare it, as |
Since PR 1513 hasn't been merged yet, I can't apply the changes from that PR to it, but if you want to test it, you'll need to remove the reference to the deprecated OP from the graph: diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 5880830b..ecbbcdec 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -7699,15 +7699,11 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
if (cparams.mtp_op_type != MTP_OP_NONE) {
ggml_tensor* hidden_states_from_main_model;
- if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
- hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
- } else {
- hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
- }
+ hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
ggml_set_input(hidden_states_from_main_model);
lctx.inp_mtp_states = hidden_states_from_main_model;
const int il_mtp = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il_mtp]; |
|
GLM-5; IQ2_KL CPU/GPU Test Without `-mtp': GPU with without |
|
The problem with CUDA possibly related to the fact that my prefill is pretty slow when the model is spread across 10 GPUs? Here is the test with six x8 and four x4 RTX 3090 (now its eight x8 and two x4 I just have no time to retest ..): main: n_kv_max = 131072, n_batch = 1024, n_ubatch = 1024, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64
|
Without MTP, that might be the case, but with MTP, the bottleneck involves switching between the main model and the MTP. Since the MTP graph is different from the main model, we end up recreating operations in the backend all the time. |
|
So, maybe I don't understand the whole MTP thing very well, but here is how I see it. Let's say that generating a draft token takes time So, then, if we only generate a single draft token, the time it took to generate the draft and then run the main model is If we did not use a draft model at all, the TG performance will be So, our "acceleration" will be This is of course always less than 1, so that's why I put "acceleration" in quotes. I.e., if we are drafting a single token, we can never ever get an acceleration from MTP. Am I missing something here? Now, let's look at what happens if we drafted 2 tokens. Let's also assume that we are not actually stopping after the 1st draft token to see if it will get accepted, we just generate 2 draft tokens in a single shot. This is possible because we are using greedy sampling for the draft tokens anyway, so we may as well incorporate the sampling into the draft generation. If we did that, we will spend
and our acceleration compared to no draft will be This will be an actual acceleration if it is greater than 1, which is satisfied if Example: if draft acceptance rate is We can generalize the above to the case of generating To be greater than 1, this requires OK, based on this analysis, this is what I would try to do:
In that way we end up with |
|
@ikawrakow Your theory seems correct, and regarding your suggestions, I’ll share my interpretation and questions, see if they make sense. The workflow you described is what we have today, but instead of delegating operations to manage hidden state and logits outside of
Today we can implement the logic of N tokens using arguments I wrote a draft to see how this would look in code (I haven’t tested it yet), and it looks something like this: struct ggml_tensor * llm_build_context::build_mtp_tail_unrolled(
const llama_layer & mtp_layer,
struct ggml_tensor * hidden_state_input,
int64_t n_embd_head,
struct ggml_cgraph * gf,
int n_draft
) {
const int il = hparams.n_layer - 1;
ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens;
if (mtp_embd_weights == nullptr) {
mtp_embd_weights = model.tok_embd;
}
ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head;
if (mtp_head_weights == nullptr) {
mtp_head_weights = model.output;
}
ggml_tensor * inp_pos_all = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_draft);
ggml_set_name(inp_pos_all, "inp_pos_draft");
ggml_set_input(inp_pos_all);
const int32_t n_batch_pad = GGML_PAD(1, GGML_KQ_MASK_PAD);
std::vector<ggml_tensor *> iter_masks(n_draft);
for (int iter = 0; iter < n_draft; iter++) {
const int32_t iter_n_kv = n_kv + iter;
if (flash_attn) {
iter_masks[iter] = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, iter_n_kv, n_batch_pad);
} else {
iter_masks[iter] = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, iter_n_kv, n_batch_pad);
}
ggml_set_input(iter_masks[iter]);
}
// First iteration: token embedding from CPU input
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights);
ggml_tensor * prev_hidden = hidden_state_input;
ggml_tensor * all_logits = nullptr;
for (int iter = 0; iter < n_draft; iter++) {
ggml_tensor * iter_pos = ggml_view_1d(ctx0, inp_pos_all, 1,
(int64_t)iter * ggml_element_size(inp_pos_all));
// Embedding + Hidden State
ggml_tensor * token_emb_norm = llm_build_norm(ctx0, token_emb, hparams,
mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, cb, il);
ggml_tensor * hidden_state_norm = llm_build_norm(ctx0, prev_hidden, hparams,
mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, cb, il);
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
cb(combined, "mtp_concat", il);
ggml_tensor * cur = llm_build_lora_mm(lctx, ctx0, mtp_layer.nextn.eh_proj, combined);
struct ggml_tensor * inpSA = cur;
cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// Self-Attention
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
nullptr, nullptr,
nullptr, nullptr,
mtp_layer.wq, mtp_layer.bq,
mtp_layer.wk, mtp_layer.bk,
mtp_layer.wv, mtp_layer.bv,
mtp_layer.attn_q_norm, mtp_layer.attn_k_norm,
0.f, il);
// RoPE: its possible to use rope_cache here since it's precomputed?
Qcur = ggml_rope_ext(ctx0, Qcur, iter_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, iter_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, iter_masks[iter],
1, // n_tokens = 1 per iteration
kv_head + iter, // write at incrementing KV positions
n_kv + iter, // attend to incrementing KV range
1.0f/sqrtf(float(n_embd_head)), cb, il);
}
// Residual + FFN
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "mtp_ffn_inp", il);
cur = llm_build_norm(ctx0, ffn_inp, hparams,
mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_post_norm", il);
// MoE FFN
{
ggml_tensor * routed_out = llm_build_std_moe_ffn(ctx0, lctx,
NULL, cur,
mtp_layer.ffn_gate_inp, NULL,
mtp_layer.ffn_up_exps, NULL,
mtp_layer.ffn_gate_exps, NULL,
mtp_layer.ffn_down_exps, NULL,
mtp_layer.ffn_exp_probs_b,
nullptr, nullptr,
nullptr, nullptr,
nullptr, nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm, true,
hparams.expert_weights_scale,
(llm_expert_gating_func_type) hparams.expert_gating_func,
LLM_FFN_SILU, cb, il, gf, true, mtp_layer.ffn_up_gate_exps);
cb(routed_out, "ffn_moe_out", il);
ggml_tensor * shared_out = llm_build_ffn(ctx0, lctx,
NULL, cur,
mtp_layer.ffn_up_shexp, NULL, NULL,
mtp_layer.ffn_gate_shexp, NULL, NULL,
mtp_layer.ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
cb(shared_out, "ffn_shexp_out", il);
cur = ggml_add(ctx0, routed_out, shared_out);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "mtp_ffn_out_resid", il);
}
// Hidden state for next iteration
prev_hidden = cur;
//LM Head
cur = llm_build_norm(ctx0, cur, hparams,
mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "result_norm", -1);
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx0, mtp_head_weights, cur);
// Collect all iterations' logits
if (all_logits == nullptr) {
all_logits = logits;
} else {
all_logits = ggml_concat(ctx0, all_logits, logits, 1);
}
if (iter < n_draft - 1) {
ggml_tensor * next_token_id = ggml_argmax(ctx0, logits);
token_emb = ggml_get_rows(ctx0, mtp_embd_weights, next_token_id);
}
}
return all_logits;
}Two questions come to mind:
By the way, I'm not sure if your suggestion refers to an improvement to the MTP that should be addressed in a future PR, or if you're thinking of this PR as a way to reduce MTP operations, such as implementing all the logic in a single graph to avoid multiple |
Currently, the MTP workflow consists of three operations, each of which performs tasks such as updating the VK cache or the draft token. A typical workflow would be:
The problem is that each operation makes a
llama_decodecall. This PR removes the last call to update the MTP KV cache, namedMTP_OP_UPDATE_ACCEPTED, and combines it with the same operation in the MTP draft, allowing the operations to be performed together.This simplification makes maintenance easier, aligns with the approach adopted for the main model, and reduces the overhead that an extra
llama_decodecall could generate, a benefit that is particularly noticeable in models like GLM 4.5 Air.The tests follow the same pattern as PR #1499
GLM 4.5 Air IQ4_XS | -ot "blk.46..*=CUDA1", --seed 42
MTP Performance
I didn't notice any performance gains in GLM 4.7, although I believe the GPU-only version shows some improvement.