Reduce MTP operations #1531

Open
SamuelOliveirads wants to merge 7 commits into ikawrakow:main from SamuelOliveirads:feat/reduce-mtp-op
Conversation

@SamuelOliveirads
Contributor

Currently, the MTP workflow consists of three operations, each performing a task such as updating the KV cache or generating the draft token. A typical workflow is:

Main Model prompt processing -> MTP prompt processing -> mtp draft -> main model validation -> mtp kv update

The problem is that each operation makes its own llama_decode call. This PR removes the last call, MTP_OP_UPDATE_ACCEPTED, which updated the MTP KV cache, and folds it into the MTP draft operation so both are performed together.

This simplification makes maintenance easier, aligns with the approach adopted for the main model, and reduces the overhead that an extra llama_decode call could generate, a benefit that is particularly noticeable in models like GLM 4.5 Air.
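As a rough before/after sketch of the per-step call pattern (the call shapes here are illustrative only, not the actual ik_llama.cpp API):

```text
before: mtp draft              -> llama_decode(MTP_OP_DRAFT)
        main model validation  -> llama_decode(main)
        mtp kv update          -> llama_decode(MTP_OP_UPDATE_ACCEPTED)

after:  mtp draft + KV update for previously accepted tokens
                               -> llama_decode(MTP_OP_DRAFT)
        main model validation  -> llama_decode(main)
```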

The tests follow the same pattern as PR #1499.

GLM 4.5 Air IQ4_XS | -ot "blk.46..*=CUDA1", --seed 42

MTP Performance

| Prompt | PR #1499 (t/s) | Current branch (t/s) | Difference (%) |
| --- | --- | --- | --- |
| Quicksort python | 12.07 | 14.35 | +18.89% |
| Test reasoning | 9.79 | 10.84 | +10.73% |
| Creative writing | 8.18 | 11.06 | +35.21% |

I didn't notice any performance gains in GLM 4.7, although I believe the GPU-only version shows some improvement.

@magikRUKKOLA

git merge pr-1513
git merge pr-1530
git merge pr-1531

?

/opt/ik_llama.cpp/ik_llama.cpp/src/llama-build-context.cpp: In member function ‘ggml_cgraph* llm_build_context::build_deepseek2()’:
/opt/ik_llama.cpp/ik_llama.cpp/src/llama-build-context.cpp:6831:76: error: ‘MTP_OP_UPDATE_ACCEPTED’ was not declared in this scope
 6831 |         if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
      |                                                                            ^~~~~~~~~~~~~~~~~~~~~~
gmake[2]: *** [src/CMakeFiles/llama.dir/build.make:177: src/CMakeFiles/llama.dir/llama-build-context.cpp.o] Error 1
gmake[1]: *** [CMakeFiles/Makefile2:2059: src/CMakeFiles/llama.dir/all] Error 2
gmake: *** [Makefile:146: all] Error 2

@magikRUKKOLA

@SamuelOliveirads

Tried to declare it (as 3), but got the following:

/opt/ik_llama.cpp/ik_llama.cpp/ggml/src/ggml.c:6901: GGML_ASSERT(a->ne[d] == b->ne[d]) failed      

@SamuelOliveirads
Contributor Author

> git merge pr-1513
> git merge pr-1530
> git merge pr-1531
>
> ?
>
> /opt/ik_llama.cpp/ik_llama.cpp/src/llama-build-context.cpp: In member function ‘ggml_cgraph* llm_build_context::build_deepseek2()’:
> /opt/ik_llama.cpp/ik_llama.cpp/src/llama-build-context.cpp:6831:76: error: ‘MTP_OP_UPDATE_ACCEPTED’ was not declared in this scope
>  6831 |         if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
>       |                                                                            ^~~~~~~~~~~~~~~~~~~~~~
> gmake[2]: *** [src/CMakeFiles/llama.dir/build.make:177: src/CMakeFiles/llama.dir/llama-build-context.cpp.o] Error 1
> gmake[1]: *** [CMakeFiles/Makefile2:2059: src/CMakeFiles/llama.dir/all] Error 2
> gmake: *** [Makefile:146: all] Error 2

Since PR #1513 hasn't been merged yet, I can't apply this PR's changes on top of it, but if you want to test that combination, you'll need to remove the reference to the deprecated op from its graph:

diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 5880830b..ecbbcdec 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -7699,15 +7699,11 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
     if (cparams.mtp_op_type != MTP_OP_NONE) {
         ggml_tensor* hidden_states_from_main_model;
 
-        if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
-            hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
-        } else {
-            hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
-        }
+        hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
         ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
         ggml_set_input(hidden_states_from_main_model);

        lctx.inp_mtp_states = hidden_states_from_main_model;
 
         const int il_mtp = hparams.n_layer - 1;
         const auto & mtp_layer = model.layers[il_mtp];

@magikRUKKOLA

magikRUKKOLA commented Mar 28, 2026

@SamuelOliveirads

GLM-5; IQ2_KL

CPU/GPU

Test #1, with -mtp.

prompt eval time =    1238.49 ms /    24 tokens (   51.60 ms per token,    19.38 tokens per second)
       eval time =  267141.41 ms /  2094 tokens (  127.57 ms per token,     7.84 tokens per second)
      total time =  268379.91 ms /  2118 tokens
draft acceptance rate = 0.61006 ( 1092 accepted /  1790 generated)
statistics mtp: #calls(b,g,a) = 1 1001 885, #gen drafts = 1001, #acc drafts = 885, #gen tokens = 1790, #acc tokens = 1092, dur(b,g,a) = 0.000, 12959.003, 0.239 ms

Without `-mtp`:

prompt eval time =    1636.49 ms /    24 tokens (   68.19 ms per token,    14.67 tokens per second)
       eval time =  105376.11 ms /  1211 tokens (   87.02 ms per token,    11.49 tokens per second)
      total time =  107012.60 ms /  1235 tokens

GPU

with -mtp:

prompt eval time =     726.22 ms /    24 tokens (   30.26 ms per token,    33.05 tokens per second)
       eval time =  146665.10 ms /  1823 tokens (   80.45 ms per token,    12.43 tokens per second)
      total time =  147391.32 ms /  1847 tokens
draft acceptance rate = 0.62898 (  968 accepted /  1539 generated)
statistics mtp: #calls(b,g,a) = 1 854 772, #gen drafts = 854, #acc drafts = 772, #gen tokens = 1539, #acc tokens = 968, dur(b,g,a) = 0.001, 7746.566, 0.199 ms

without -mtp:

prompt eval time =     575.80 ms /    24 tokens (   23.99 ms per token,    41.68 tokens per second)
       eval time =   62643.84 ms /  1575 tokens (   39.77 ms per token,    25.14 tokens per second)
      total time =   63219.64 ms /  1599 tokens

@magikRUKKOLA

@SamuelOliveirads

Is the problem with CUDA possibly related to the fact that my prefill is pretty slow when the model is spread across 10 GPUs?

Here is the test with six x8 and four x4 RTX 3090s (it's now eight x8 and two x4, I just haven't had time to retest):

main: n_kv_max = 131072, n_batch = 1024, n_ubatch = 1024, flash_attn = 1, n_gpu_layers = 99, n_threads = 64, n_threads_batch = 64

PP TG N_KV T_PP s S_PP t/s T_TG s S_TG t/s
1024 256 0 4.353 235.24 9.928 25.79
1024 256 1024 4.795 213.58 10.140 25.25
1024 256 2048 5.332 192.05 10.561 24.24
1024 256 3072 5.630 181.88 10.697 23.93
1024 256 4096 6.378 160.56 10.923 23.44
1024 256 5120 6.846 149.59 11.026 23.22
1024 256 6144 7.128 143.65 11.114 23.03
1024 256 7168 7.513 136.30 11.162 22.93
1024 256 8192 7.891 129.77 11.600 22.07
1024 256 9216 8.357 122.53 11.752 21.78
1024 256 10240 8.875 115.38 11.818 21.66
1024 256 11264 9.221 111.06 11.870 21.57
1024 256 12288 9.489 107.91 12.276 20.85
1024 256 13312 9.815 104.33 12.445 20.57
1024 256 14336 10.163 100.76 12.529 20.43
1024 256 15360 10.679 95.89 12.579 20.35
1024 256 16384 11.004 93.05 12.980 19.72
1024 256 17408 11.195 91.47 13.157 19.46
1024 256 18432 11.513 88.94 13.233 19.35
1024 256 19456 11.855 86.37 13.282 19.27
1024 256 20480 12.396 82.61 13.650 18.75
1024 256 21504 12.769 80.19 13.840 18.50
1024 256 22528 12.874 79.54 13.900 18.42
1024 256 23552 13.211 77.51 13.958 18.34
1024 256 24576 13.565 75.49 14.327 17.87
1024 256 25600 14.137 72.44 14.528 17.62
1024 256 26624 14.534 70.45 14.591 17.55
1024 256 27648 14.907 68.69 14.652 17.47
1024 256 28672 14.997 68.28 15.012 17.05
1024 256 29696 15.340 66.75 15.207 16.83
1024 256 30720 15.964 64.14 15.265 16.77
1024 256 31744 16.360 62.59 15.337 16.69
1024 256 32768 16.738 61.18 15.670 16.34
1024 256 33792 16.761 61.09 15.875 16.13
1024 256 34816 16.918 60.53 15.943 16.06
1024 256 35840 17.410 58.82 16.017 15.98
1024 256 36864 17.793 57.55 16.385 15.62
1024 256 37888 18.113 56.54 16.569 15.45
1024 256 38912 18.137 56.46 16.809 15.23
1024 256 39936 17.713 57.81 16.711 15.32
1024 256 40960 18.349 55.81 17.063 15.00
1024 256 41984 19.233 53.24 17.263 14.83
1024 256 43008 19.387 52.82 17.361 14.75
1024 256 44032 19.725 51.91 17.409 14.71

@SamuelOliveirads
Contributor Author

> The problem with CUDA possibly related to the fact that my prefill is pretty slow when the model is spread across 10 GPUs?

Without MTP, that might be the case, but with MTP the bottleneck involves switching between the main model and the MTP. Since the MTP graph is different from the main model's graph, we end up recreating operations in the backend all the time.

@ikawrakow
Owner

So, maybe I don't understand the whole MTP thing very well, but here is how I see it.

Let's say that generating a draft token takes time $T_d$. Let's denote the time it takes to run the main model for a single token with $T$. For simplicity, let's assume that the time it takes to decode a batch of 1 token is about the same as the time it takes to decode a batch with $N$ tokens (where $N$ is small, say less than 8 tokens or so). This is not strictly true, and in some cases a batch of 2 tokens may actually take almost 2 times as long as a single token, but just to make things simpler.

So, then, if we only generate a single draft token, the time it took to generate the draft and then run the main model is $T + T_d$. Let's say that the probability to accept the draft token is $p$. Then, on average, after generating a single draft token we will end up with $p$ tokens (where $p < 1$), so our TG performance will be

$$\frac{p}{T + T_d}$$

If we did not use a draft model at all, the TG performance will be

$$\frac{1}{T}$$

So, our "acceleration" will be

$$p \frac{T}{T + T_d} = \frac{p}{1 + \alpha},\quad\quad \alpha = \frac{T_d}{T}$$

This is of course always less than 1, so that's why I put "acceleration" in quotes. I.e., if we are drafting a single token, we can never ever get an acceleration from MTP. Am I missing something here?

Now, let's look at what happens if we drafted 2 tokens. Let's also assume that we are not actually stopping after the 1st draft token to see if it will get accepted, we just generate 2 draft tokens in a single shot. This is possible because we are using greedy sampling for the draft tokens anyway, so we may as well incorporate the sampling into the draft generation. If we did that, we will spend $2 T_d$ to generate the draft, and $T$ to decode the 2 draft tokens in a single batch (as per above simplifying assumption). We will then end up with

  • With probability $1 - p$ we reject the 1st draft token, so we end up with zero generated tokens.
  • With probability $p (1 - p)$ we accept the 1st draft token but reject the 2nd, so we end up with 1 generated token.
  • With probability $p^2$ we accept both draft tokens, and end up with 2 generated tokens.

Hence, on average, we will have $p (1 - p) + 2 p^2 = p (1 + p)$ generated tokens. Our TG performance will be

$$\frac{p (1 + p)}{T + 2 T_d}$$

and our acceleration compared to no draft will be

$$\frac{p (1 + p)}{1 + 2 \alpha}$$

This will be an actual acceleration if it is greater than 1, which is satisfied if

$$\alpha < \frac{p (1 + p) - 1}{2}$$

Example: if draft acceptance rate is $p = 0.7$, then we need $\alpha < 0.095$, i.e., generating a draft token should not take more than 9.5% of the time it takes to decode a token with the full model. Also, in order to be able to achieve acceleration at all (even for an infinitely fast draft generator), we need $p (1 + p) - 1 > 0$, so $p > (\sqrt{5}-1)/2 \approx 0.618$. I.e., if we are observing a draft acceptance rate of less than 62%, we might as well go home and do something else.

We can generalize the above to the case of generating $N$ draft tokens, all at once. We end up with acceleration given by

$$\frac{1 - p^N}{1 - p}~~\frac{p}{1 + N \alpha}$$

To be greater than 1, this requires

$$\alpha < \frac{1}{N} \left( p \frac{1 - p^N}{1 - p} - 1 \right)$$

OK, based on this analysis, this is what I would try to do:

  • The MTP draft generating function should take the number $N$ of draft tokens to be generated as argument
  • When building the graph, there will be a loop over $N$ tokens. In each loop iteration, a graph for generating 1 token is built; this is followed by ggml_argmax, which selects the highest-probability token, followed by ggml_get_rows with the selected token, which becomes the input for the next iteration.

In that way we end up with $N$ draft tokens with a single graph evaluation. The generated draft tokens become the input for a batch of $N$ tokens that we evaluate with the main model. Only after we have done both of these things, we start sampling tokens, stopping the iteration if the sampled token was not the same as the draft token.

@SamuelOliveirads
Contributor Author

@ikawrakow Your theory seems correct. Regarding your suggestions, I'll share my interpretation and questions; see if they make sense.

The workflow you described is what we have today, but instead of delegating operations to manage hidden state and logits outside of llama_decode, the suggestion is to do it in a single graph. The potential benefits I see involve:

  1. Reducing bottlenecks in copying and transferring information, especially from CPU to GPU.
  2. Fixing the batch size at N, so that the graph-reuse logic and the scheduler can keep reusing the same graph.

Today we can get the N-token logic with the arguments --draft-max N --draft-p-min 0.0. Looking at the logic implemented in SGLang, that approach seems more efficient to me than allowing an early stop, even though it means we compute all N tokens even when some have low confidence.

I wrote a draft to see how this would look in code (I haven’t tested it yet), and it looks something like this:

struct ggml_tensor * llm_build_context::build_mtp_tail_unrolled(
    const llama_layer & mtp_layer,
    struct ggml_tensor * hidden_state_input,
    int64_t n_embd_head,
    struct ggml_cgraph * gf,
    int n_draft
) {
    const int il = hparams.n_layer - 1;

    ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens;
    if (mtp_embd_weights == nullptr) {
        mtp_embd_weights = model.tok_embd;
    }

    ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head;
    if (mtp_head_weights == nullptr) {
        mtp_head_weights = model.output;
    }

    ggml_tensor * inp_pos_all = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_draft);
    ggml_set_name(inp_pos_all, "inp_pos_draft");
    ggml_set_input(inp_pos_all);

    const int32_t n_batch_pad = GGML_PAD(1, GGML_KQ_MASK_PAD);
    std::vector<ggml_tensor *> iter_masks(n_draft);
    for (int iter = 0; iter < n_draft; iter++) {
        const int32_t iter_n_kv = n_kv + iter;
        if (flash_attn) {
            iter_masks[iter] = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, iter_n_kv, n_batch_pad);
        } else {
            iter_masks[iter] = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, iter_n_kv, n_batch_pad);
        }
        ggml_set_input(iter_masks[iter]);
    }

    // First iteration: token embedding from CPU input
    ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights);
    ggml_tensor * prev_hidden = hidden_state_input;

    ggml_tensor * all_logits = nullptr;

    for (int iter = 0; iter < n_draft; iter++) {
        ggml_tensor * iter_pos = ggml_view_1d(ctx0, inp_pos_all, 1,
            (int64_t)iter * ggml_element_size(inp_pos_all));

        // Embedding + Hidden State
        ggml_tensor * token_emb_norm = llm_build_norm(ctx0, token_emb, hparams,
            mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, cb, il);
        ggml_tensor * hidden_state_norm = llm_build_norm(ctx0, prev_hidden, hparams,
            mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, cb, il);

        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
        cb(combined, "mtp_concat", il);
        ggml_tensor * cur = llm_build_lora_mm(lctx, ctx0, mtp_layer.nextn.eh_proj, combined);

        struct ggml_tensor * inpSA = cur;

        cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, cb, il);
        cb(cur, "attn_norm", il);

        // Self-Attention
        {
            auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
                    nullptr, nullptr,
                    nullptr, nullptr,
                    mtp_layer.wq, mtp_layer.bq,
                    mtp_layer.wk, mtp_layer.bk,
                    mtp_layer.wv, mtp_layer.bv,
                    mtp_layer.attn_q_norm, mtp_layer.attn_k_norm,
                    0.f, il);

            // RoPE: is it possible to use rope_cache here, since it's precomputed?
            Qcur = ggml_rope_ext(ctx0, Qcur, iter_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
            Kcur = ggml_rope_ext(ctx0, Kcur, iter_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);

            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);

            cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                            mtp_layer.wo, NULL,
                            Kcur, Vcur, Qcur, iter_masks[iter],
                            1,                  // n_tokens = 1 per iteration
                            kv_head + iter,     // write at incrementing KV positions
                            n_kv + iter,        // attend to incrementing KV range
                            1.0f/sqrtf(float(n_embd_head)), cb, il);
        }

        // Residual + FFN
        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "mtp_ffn_inp", il);

        cur = llm_build_norm(ctx0, ffn_inp, hparams,
            mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
        cb(cur, "attn_post_norm", il);

        // MoE FFN
        {
            ggml_tensor * routed_out = llm_build_std_moe_ffn(ctx0, lctx,
                            NULL, cur,
                            mtp_layer.ffn_gate_inp,  NULL,
                            mtp_layer.ffn_up_exps,   NULL,
                            mtp_layer.ffn_gate_exps, NULL,
                            mtp_layer.ffn_down_exps, NULL,
                            mtp_layer.ffn_exp_probs_b,
                            nullptr, nullptr,
                            nullptr, nullptr,
                            nullptr, nullptr,
                            n_expert, n_expert_used,
                            LLM_FFN_SILU, hparams.expert_weights_norm, true,
                            hparams.expert_weights_scale,
                            (llm_expert_gating_func_type) hparams.expert_gating_func,
                            LLM_FFN_SILU, cb, il, gf, true, mtp_layer.ffn_up_gate_exps);
            cb(routed_out, "ffn_moe_out", il);

            ggml_tensor * shared_out = llm_build_ffn(ctx0, lctx,
                            NULL, cur,
                            mtp_layer.ffn_up_shexp,   NULL, NULL,
                            mtp_layer.ffn_gate_shexp, NULL, NULL,
                            mtp_layer.ffn_down_shexp, NULL, NULL,
                            NULL,
                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
            cb(shared_out, "ffn_shexp_out", il);

            cur = ggml_add(ctx0, routed_out, shared_out);
            cb(cur, "ffn_out", il);

            cur = ggml_add(ctx0, cur, ffn_inp);
            cb(cur, "mtp_ffn_out_resid", il);
        }

        // Hidden state for next iteration
        prev_hidden = cur;

        //LM Head
        cur = llm_build_norm(ctx0, cur, hparams,
            mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
        cb(cur, "result_norm", -1);

        ggml_tensor * logits = llm_build_lora_mm(lctx, ctx0, mtp_head_weights, cur);

        // Collect all iterations' logits
        if (all_logits == nullptr) {
            all_logits = logits;
        } else {
            all_logits = ggml_concat(ctx0, all_logits, logits, 1);
        }

        if (iter < n_draft - 1) {
            ggml_tensor * next_token_id = ggml_argmax(ctx0, logits);
            token_emb = ggml_get_rows(ctx0, mtp_embd_weights, next_token_id);
        }
    }

    return all_logits;
}

Two questions come to mind:

  1. Specifically for the KV cache, will the graph be able to correctly order the positions? I wonder if ggml won’t interpret this as “these two subgraphs don’t share any tensors, so I’ll parallelize them,” which could break the data filling logic.
  2. From what I understand, you can’t pre-compute the rope_cache since each iteration has a different position, right?

By the way, I'm not sure if your suggestion refers to an improvement to the MTP that should be addressed in a future PR, or if you're thinking of this PR as a way to reduce MTP operations, such as implementing all the logic in a single graph to avoid multiple llama_decode calls. MTP still needs to update its KV cache for prompt processing, so I’m viewing your suggestion more as a separate feature.
