Add GLM 5 MTP #1513

Open
SamuelOliveirads wants to merge 35 commits into ikawrakow:main from SamuelOliveirads:feat/glm5-mtp

Conversation

@SamuelOliveirads
Contributor

Adds MTP support for GLM-5. To try it, use the `-mtp` flag to activate it, and `--draft-max` / `--draft-p-min` to control how many tokens you want to draft.

Tests applied

  1. Test 1: Write a quick sort python algorithm, answer only the code.
  2. Test 2: Extract all core events with their exact dates into a bulleted list. (I copied the "Top" YouTube section from Wikipedia: https://en.wikipedia.org/wiki/YouTube#)
  3. Test 3: Write an unexpected short story about someone exploring a cyberpunk city in 2077, but the main character's internal dialogue is deeply analytical and philosophical.

GLM 5 smol-IQ2_KS - Draft size = 10, p-min = 0.85, -ot "blk.78..*=CUDA1", --seed 42

Without MTP vs With MTP

| Prompt | Baseline (t/s) | MTP (t/s) | Accept Rate (%) | Difference (%) |
|---|---|---|---|---|
| Quicksort python | 8.18 | 6.90 | 62.2% | -15.65% |
| Test reasoning | 8.23 | 5.36 | 57.8% | -34.87% |
| Creative writing | 8.13 | 5.11 | 50.8% | -37.15% |
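The "Difference (%)" column can be reproduced directly from the two throughput columns; a quick check against the table values:

```python
# Recompute the "Difference (%)" column of the benchmark table:
# difference = (mtp_ts - baseline_ts) / baseline_ts * 100
rows = [
    ("Quicksort python", 8.18, 6.90),
    ("Test reasoning",   8.23, 5.36),
    ("Creative writing", 8.13, 5.11),
]
for name, baseline, mtp in rows:
    diff = (mtp - baseline) / baseline * 100
    print(f"{name}: {diff:.2f}%")
```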

Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ik_llama.cpp`.

Changes include:
- Updating `llama_batch` to support `mtp_params`.
- Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
- Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
- Adapting the embedding extraction logic to skip MTP update passes.
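The drafting behaviour controlled by `--draft-max` / `--draft-p-min` can be pictured roughly like this (a simplified Python sketch of the semantics, not the PR's actual C++ code; `mtp_head` is a hypothetical callable standing in for the NextN layer):

```python
# Simplified sketch: keep drafting tokens until the MTP head's
# confidence drops below draft_p_min or draft_max tokens have been
# produced. mtp_head is a hypothetical callable returning
# (token, probability, next_hidden_state).
def draft_tokens(mtp_head, hidden_state, draft_max=16, draft_p_min=0.85):
    drafted = []
    for _ in range(draft_max):
        token, prob, hidden_state = mtp_head(hidden_state)
        if prob < draft_p_min:
            break  # low confidence: stop drafting, let the target model verify
        drafted.append(token)
    return drafted
```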
@ikawrakow
Owner

I think it would be better to first achieve performance improvement via MTP before adding MTP for more models.

@jukofyork
Contributor

jukofyork commented Mar 26, 2026

Have you tried -mla 1 (assuming you used -mla 3)?

I found that Kimi-K2 got no improvement using -mla 3 with my draft model, but did get a ~20% improvement using -mla 1 as long as you used --draft-min 2.

I never investigated too deeply why, as I've now moved on to Kimi-K2.5, which doesn't work well with temperature=0. When I looked at the code, there shouldn't actually be any difference between -mla 1 and -mla 3 for TG speed (e.g., -mla 3 has a "decompression" threshold with a value way higher than any draft could reach).

Maybe worth a try though.

@magikRUKKOLA

@jukofyork

Have you tried using Qwen3.5-35B as a draft for Qwen3.5-397B? I remember doing that with older models and getting a decent speedup.

@magikRUKKOLA

magikRUKKOLA commented Mar 26, 2026

@SamuelOliveirads

/opt/ik_llama.cpp/ik_llama.cpp/src/llama.cpp:4087: GGML_ASSERT(lctx.embd != nullptr) failed                     
                                                                          
Using host libthread_db library "/usr/lib/x86_64-linux-gnu/libthread_db.so.1".                                              
0x00007fdd8d6a76be in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6                                                        
#0  0x00007fdd8d6a76be in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6                                                    
#1  0x00007fdd8d69be64 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6                                                    
#2  0x00007fdd8d69bead in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6                                                    
#3  0x00007fdd8d707c07 in wait4 () from /usr/lib/x86_64-linux-gnu/libc.so.6                                                 
#4  0x00007fdd8dd37de8 in ggml_abort () from /opt/ik_llama.cpp/ik_llama.cpp/build/ggml/src/libggml.so                       
#5  0x00007fdd9c4b047a in llama_decode () from /opt/ik_llama.cpp/ik_llama.cpp/build/src/libllama.so                         
#6  0x000055dfec6a56f2 in llama_init_from_gpt_params(gpt_params&) ()                                                        
#7  0x000055dfec5aaac2 in server_context::load_model(gpt_params const&) ()                                                  
#8  0x000055dfec4daedb in main ()                             
[Inferior 1 (process 3369289) detached]                       

/opt/ik_llama.cpp/ik_llama.cpp/src/llama.cpp:4087: GGML_ASSERT(lctx.embd != nullptr) failed
[Detaching after fork from child process 3370107]
❌️ warning: process 3369928 is already traced by process 3369918
ptrace: Operation not permitted.
❌️ No stack.
❌️ The program is not being run.

Thread 1 "llama-server" received signal SIGABRT, Aborted.
0x00007fffe8ea13bc in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
(gdb) bt full
#0  0x00007fffe8ea13bc in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
No symbol table info available.
#1  0x00007fffe8e4a942 in raise () from /usr/lib/x86_64-linux-gnu/libc.so.6
No symbol table info available.
#2  0x00007fffe8e324ac in abort () from /usr/lib/x86_64-linux-gnu/libc.so.6
No symbol table info available.
#3  0x00007fffe9537dfd in ggml_abort () from /opt/ik_llama.cpp/ik_llama.cpp/build/ggml/src/libggml.so
No symbol table info available.
#4  0x00007ffff7cb047a in llama_decode () from /opt/ik_llama.cpp/ik_llama.cpp/build/src/libllama.so
No symbol table info available.
#5  0x00005555557d56f2 in llama_init_from_gpt_params(gpt_params&) ()
No symbol table info available.
#6  0x00005555556daac2 in server_context::load_model(gpt_params const&) ()
No symbol table info

@jukofyork
Contributor

@jukofyork

Have you tried using Qwen3.5-35B as a draft for Qwen3.5-397B? I remember doing that with older models and getting a decent speedup.

I'm really only using Kimi-K2.5 now (at the full 256k context), as I have it running rock solid in opencode and everything else is worse and/or messes up tool calls.

@magikRUKKOLA

@jukofyork

Are there any standardized tests to check an LLM's tool-calling performance etc. that can be run locally?

@jukofyork
Contributor

@jukofyork

Are there any standardized tests to check an LLM's tool-calling performance etc. that can be run locally?

Not really, but you will very quickly find out if it starts hallucinating tool calls in the chat (DeepSeek is particularly bad and fails almost instantly).

@magikRUKKOLA

magikRUKKOLA commented Mar 26, 2026

@jukofyork

Okay, what about using smol-IQ1_KT fully GPU-offloaded as a draft for a larger quant with only the head and the KV cache offloaded?

Getting about 31 t/s decode at zero context and 21 t/s at 32k context. [EDIT]: naaah, I don't think it's worth it; 21 t/s at 32k context is already slow enough. Hmm, I should probably finally try with dual EPYC.

@SamuelOliveirads
Contributor Author

I think it would be better to first achieve performance improvement via MTP before adding MTP for more models.

@ikawrakow To be honest, I already had GLM5 and use it fairly often, so I wanted to add it to have a point of comparison. As for other MTP models, I don't plan on adding them for now, especially since we don't retain the MTP layer and it's unlikely anyone would want to re-quantize just to test a slow feature.


Have you tried -mla 1 (assuming you used -mla 3)?

@jukofyork With MLA 1 or 3 I saw slightly lower performance; for me the ranking was: no MLA > MLA 3 > MLA 1. To be honest, I haven't been fine-tuning the arguments for a while, but since you mentioned --draft-min, I have an idea that might help define that parameter better; I'll see how it works in practice later.


/opt/ik_llama.cpp/ik_llama.cpp/src/llama.cpp:4087: GGML_ASSERT(lctx.embd != nullptr) failed                     

@magikRUKKOLA Could you give me some details about the arguments used? I tested it with Kimi K2.5, thinking there was an incompatibility with MTP, then I tested it with GLM5 without MTP and didn't get any errors.

@magikRUKKOLA

magikRUKKOLA commented Mar 26, 2026

@SamuelOliveirads

Could you give me some details about the arguments used?

/opt/ik_llama.cpp/ik_llama.cpp/build/bin/llama-server \
    --model /opt/ubergarm/GLM-5-GGUF/smol-IQ2_KS/GLM-5-smol-IQ2_KS-00001-of-00006.gguf \
    --alias ubergarm/GLM-5-smol-IQ2_KS \
    --ctx-size $((128 * 1024)) \
    -b $((1024)) -ub $((1024)) \
    --mlock \
    --temp 0.0 --top-p 1.0 --top-k 0 \
    -ctk q6_0 \
    -ctv q6_0 \
    -mtp \
    -khad \
    -ger \
    -smgs \
    -sas \
    -muge \
    -mea 16 \
    -amb 16 \
    --merge-qkv \
    --graph-reduce-type bf16 \
    --split-mode layer \
    --main-gpu 0 \
    --max-gpu 0 \
    --n-gpu-layers 99 \
    --threads $(grep ^cpu\\scores /proc/cpuinfo | uniq | awk '{print $4}' | xargs -I{} echo "{}-0" | bc) \
    --host 0.0.0.0 \
    --port 8080 \
    --log-enable \
    --logdir /var/log/ \
    --jinja \
    --special \
    --verbosity 1 \
    --verbose-prompt \
    --reasoning-format auto \
    --prompt-cache "$HOME/.cache/ik_llama.cpp/prompt-cache.bin" --prompt-cache-all \
    --slot-save-path "$HOME/.cache/ik_llama.cpp/slot.bin" \
    --lookup-cache-dynamic "$HOME/.cache/ik_llama.cpp/slot.bin" \
    --keep -1 \
    --slot-prompt-similarity 0.35 \
    --metrics \
    -cuda fusion=1

[EDIT]: woops. I had to use --threads 1. But that would not matter much anyway.

@jukofyork
Contributor

@jukofyork With MLA 1 or 3 I saw slightly lower performance; for me the ranking was: no MLA > MLA 3 > MLA 1. To be honest, I haven't been fine-tuning the arguments for a while, but since you mentioned --draft-min, I have an idea that might help define that parameter better; I'll see how it works in practice later.

On mainline llama.cpp I found the best thing to do is run a sweep of all batch sizes from 1 to 64 and plot them. You often see things in the first 2-8 batch sizes that help tune the draft parameters.

@magikRUKKOLA

magikRUKKOLA commented Mar 26, 2026

@SamuelOliveirads

[EDITED]:

/opt/ubergarm/Kimi-K2.5-GGUF/smol-IQ1_KT:

WARN [              load_model] WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.

@magikRUKKOLA

magikRUKKOLA commented Mar 26, 2026

GLM 5 smol-IQ2_KS - Draft size = 10, p-min = 0.85, -ot "blk.78..*=CUDA1", --seed 42

What arguments should I use, once again? How do I set the draft size?

[EDIT]: Oh. I see. So via the --draft-max which is 16 by default.

@magikRUKKOLA

I tested it with Kimi K2.5, thinking there was an incompatibility with MTP, then I tested it with GLM5 without MTP and didn't get any errors.

It's with -mtp provided for GLM5.

@SamuelOliveirads
Contributor Author

@magikRUKKOLA I wasn't able to reproduce the same error with your arguments, the only difference was that I couldn't fully offload to the GPU with such a large model. That said, there were some errors that occurred, and they were fixed after the most recent rebase of the branch. Since your first test was done before that, please try making a new pull.

To provide more context, the models that have MTP and support it are GLM 4.5/4.6/4.7 and 5.0. You can try running the -mtp command with any other model, and it will be disabled (I used Kimi K2.5 as a test to see if this logic was causing your crash before).

Currently, MTP only supports --draft-max and --draft-p-min.
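For reference, those flags can be combined like this (a minimal sketch; the binary location and model path are placeholders, while `-mtp`, `--draft-max`, and `--draft-p-min` are the flags this PR exposes):

```shell
# Sketch only: paths are placeholders, flag values are examples.
./build/bin/llama-server \
    --model /path/to/GLM-5.gguf \
    -mtp \
    --draft-max 10 \
    --draft-p-min 0.85
```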


On mainline llama.cpp I found the best thing to do is run a sweep of all batch sizes from 1 to 64 and plot them. You often see things in the first 2-8 batch sizes that help tune the draft parameters.

@jukofyork I believe that certain parameters, such as draft-max, draft-min, and p-min, could be optimized, perhaps using a controller that can adjust the parameters based on the hit rate of the speculative models. Since you’re running some tests, are there any parameters you’d like me to test?
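A minimal version of that controller idea might look like this (a hypothetical sketch; the function name, window-based update, and target rate are illustrative, not part of the PR):

```python
# Hypothetical controller sketch: nudge draft_max toward a target
# acceptance rate based on the hit rate observed over a window.
def adjust_draft_max(draft_max, accepted, generated,
                     target=0.6, lo=1, hi=16):
    if generated == 0:
        return draft_max            # no data yet, keep the current setting
    rate = accepted / generated
    if rate < target:               # too many rejected drafts: draft less
        draft_max -= 1
    elif rate > target:             # drafts mostly accepted: draft more
        draft_max += 1
    return max(lo, min(hi, draft_max))
```

In practice the cost of a rejected draft depends on batch economics (see jukofyork's sweep suggestion above), so the target rate itself would likely need tuning per backend.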

@magikRUKKOLA

magikRUKKOLA commented Mar 26, 2026

@SamuelOliveirads

Aha! Yes, it does not crash indeed.
With -mtp it's a lot slower. I will publish the results for the first test.

It's like molasses, yeah.

VERB [speculative_decoding_accept] speculative decoding result | tid="140439775965184" timestamp=1774557499 id_slot=0 accepted=1 total=0 new_n_past=1562
VERB [            update_slots] run slots completed | tid="140439775965184" timestamp=1774557499
VERB [              start_loop] wait for new task | tid="140439775965184" timestamp=1774557499
VERB [              start_loop] new task may arrive | tid="140439775965184" timestamp=1774557499
slot print_timing: id  0 | task 184 | 
prompt eval time =      56.74 ms /     1 tokens (   56.74 ms per token,    17.62 tokens per second)
       eval time =  129752.55 ms /  1538 tokens (   84.36 ms per token,    11.85 tokens per second)
      total time =  129809.29 ms /  1539 tokens
VERB [              start_loop] update_multitasks | tid="140439775965184" timestamp=1774557499
draft acceptance rate = 0.57330 (  786 accepted /  1371 generated)

without -mtp:

prompt eval time =     600.34 ms /    24 tokens (   25.01 ms per token,    39.98 tokens per second)
       eval time =   62964.02 ms /  1575 tokens (   39.98 ms per token,    25.01 tokens per second)
      total time =   63564.37 ms /  1599 tokens

Overall, with -mtp decode is about 2 times slower.
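The "about 2 times" figure follows directly from the two per-token eval times in the logs above:

```python
# Check the slowdown claim from the two eval lines above.
ms_per_token_mtp = 84.36    # eval time with -mtp
ms_per_token_base = 39.98   # eval time without -mtp
slowdown = ms_per_token_mtp / ms_per_token_base
print(f"{slowdown:.2f}x slower decode")
```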

@SamuelOliveirads
Contributor Author

Overall, with -mtp decode is about 2 times slower.

Don't worry, one day it will be optimized enough to be worth it (I hope).

@magikRUKKOLA

@SamuelOliveirads

Should I re-try with hybrid inference?

@jukofyork
Contributor

jukofyork commented Mar 26, 2026

@magikRUKKOLA I wasn't able to reproduce the same error with your arguments, the only difference was that I couldn't fully offload to the GPU with such a large model. That said, there were some errors that occurred, and they were fixed after the most recent rebase of the branch. Since your first test was done before that, please try making a new pull.

To provide more context, the models that have MTP and support it are GLM 4.5/4.6/4.7 and 5.0. You can try running the -mtp command with any other model, and it will be disabled (I used Kimi K2.5 as a test to see if this logic was causing your crash before).

Currently, MTP only supports --draft-max and --draft-p-min

On mainline llama.cpp I found the best thing to do is run a sweep of all batch sizes from 1 to 64 and plot them. You often see things in the first 2-8 batch sizes that help tune the draft parameters.

@jukofyork I believe that certain parameters, such as draft-max, draft-min, and p-min, could be optimized, perhaps using a controller that can adjust the parameters based on the hit rate of the speculative models. Since you’re running some tests, are there any parameters you’d like me to test?

See the posts in this thread, starting here:

ggml-org/llama.cpp#10466 (comment)

I tried to simplify it to the bare minimum here:

ggml-org/llama.cpp#17034

but nobody seemed interested, and mainline llama.cpp's speculative decoding logic keeps getting more and more complex, so I'm not really sure if I can revive it now.

The key thing from all my experiments is that you can't really just use a fixed min-p as there are all sorts of weird jumps in the batch costs depending on the backend(s) used, FA optimisations, MMQ thresholds, and so on... You really have to consider the sequence probabilities and batch costs for each batch size to get it working well:

(Figure: per-batch-size cost and sequence-probability sweep plot)

Some kind of adaptive controller would be the next step, but there was pretty much zero interest in that discussion and PR...


I'm also not convinced the current logic is correct:

ggml-org/llama.cpp#10466 (comment)

The code has so many tricky optimisations in it now, but I think you can show that if batch=2 costs more than 2 × batch=1 we should never actually use batch=2. The state of the code when I made that post, though, meant you would always try batch=2, even if the single token you saw before breaking from the loop had a super low probability.

If you look at the costs for my GLM-4.6 in the graph above, it never makes sense to try batch=2 as it is slower than just running batch=1 twice.
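That argument can be made concrete: the expected number of tokens accepted per verification pass, divided by the cost of the verification batch, determines whether a given draft length pays off. A toy sketch (assuming, for illustration only, an independent per-token acceptance probability p and a made-up cost model):

```python
# Expected tokens per verification pass with draft length n: the target
# model always yields one token, plus p + p^2 + ... + p^n expected
# accepted draft tokens. Divide by the cost of a batch of size n+1.
def throughput(n, p, batch_cost):
    expected_tokens = sum(p**k for k in range(n + 1))  # 1 + p + ... + p^n
    return expected_tokens / batch_cost(n + 1)

# Toy cost model where batch=2 costs more than 2x batch=1: drafting even
# one token then loses for any p < 1, matching the claim above.
cost = lambda b: 1.0 if b == 1 else 2.1 * (b - 1)
assert throughput(0, 0.9, cost) > throughput(1, 0.9, cost)
```

The inequality holds because the best case gain from one draft token is a factor of 1 + p < 2, which cannot compensate a batch that costs more than twice as much.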

@SamuelOliveirads
Contributor Author

Should I re-try with hybrid inference?

@magikRUKKOLA If you want to test whether the GLM5 MTP code works, go ahead, I appreciate it; but in terms of performance it shouldn't make much of a difference.


See the posts in this thread, starting here:

ggml-org/llama.cpp#10466 (comment)

I tried to simplify it to the bare minimum here:

ggml-org/llama.cpp#17034

@jukofyork This is great material; I need more time to read through the details, but I'll definitely use it when I start working on this feature. I believe the parameters can be inferred in real time, which allows adapting the settings to the user's needs and use cases. At the end of the session, a snapshot of the current metrics could be provided so the user can adopt it as a default in the future if they wish.

@magikRUKKOLA

magikRUKKOLA commented Mar 27, 2026

@SamuelOliveirads

GLM5 IQ2_KL --cpu-moe:

without -mtp:

prompt eval time =    1776.41 ms /    24 tokens (   74.02 ms per token,    13.51 tokens per second)
       eval time =  111475.06 ms /  1211 tokens (   92.05 ms per token,    10.86 tokens per second)
      total time =  113251.47 ms /  1235 tokens

with -mtp:

prompt eval time =    1308.86 ms /    24 tokens (   54.54 ms per token,    18.34 tokens per second)
       eval time =  192046.45 ms /  1437 tokens (  133.64 ms per token,     7.48 tokens per second)
      total time =  193355.32 ms /  1461 tokens
VERB [speculative_decoding_accept] speculative decoding result | tid="139636109680640" timestamp=1774596621 id_slot=0 accepted=1 total=0 new_n_past=1461
draft acceptance rate = 0.57460 (  751 accepted /  1307 generated)


@SamuelOliveirads
Contributor Author

GLM5 IQ2_KL --cpu-moe:

without -mtp:

prompt eval time =    1776.41 ms /    24 tokens (   74.02 ms per token,    13.51 tokens per second)
       eval time =  111475.06 ms /  1211 tokens (   92.05 ms per token,    10.86 tokens per second)
      total time =  113251.47 ms /  1235 tokens

with -mtp:

prompt eval time =    1308.86 ms /    24 tokens (   54.54 ms per token,    18.34 tokens per second)
       eval time =  192046.45 ms /  1437 tokens (  133.64 ms per token,     7.48 tokens per second)
      total time =  193355.32 ms /  1461 tokens
VERB [speculative_decoding_accept] speculative decoding result | tid="139636109680640" timestamp=1774596621 id_slot=0 accepted=1 total=0 new_n_past=1461
draft acceptance rate = 0.57460 (  751 accepted /  1307 generated)

The performance loss is consistent with my tests, which leads me to believe the initial gains will come in hybrid/CPU-only inference, while in the future the main gains will come from the GPU.
