
Conversation

@g2mt (Contributor) commented Jul 25, 2025

Related to #322

This is a port of the speculative decoding function for llama-server from the upstream code base.

Changes:

  • Updated the llama-server source code
  • Added several functions needed for speculative decoding
  • Added prefixes to KV cache tensors to support loading multiple models

I used Qwen3-235B in this PR.
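
For reviewers unfamiliar with the flow, the ported logic boils down to a draft-then-verify loop along these lines (a schematic sketch with made-up helper names, not the actual server code):

// Schematic of the speculative decoding loop (helper names are placeholders;
// the real logic lives in the server's slot update code).
#include <cstdio>
#include <vector>

// Stand-in for the cheap draft model: propose a few likely next tokens.
static std::vector<int> draft_propose(int last_pos, int n_draft) {
    std::vector<int> out;
    for (int i = 0; i < n_draft; ++i) out.push_back(last_pos + i + 1); // dummy proposal
    return out;
}

// Stand-in for the target model's choice at a given position.
static int target_sample(int pos) { return pos; }

int main() {
    int n_past = 0;
    const int n_draft = 4;
    while (n_past < 16) {
        // 1) draft a short continuation with the small model
        std::vector<int> draft = draft_propose(n_past, n_draft);
        // 2) evaluate the whole draft in ONE batch on the target model and keep
        //    the longest prefix the target agrees with, plus one corrected token
        int accepted = 0;
        while (accepted < (int) draft.size() &&
               draft[accepted] == target_sample(n_past + accepted + 1)) {
            ++accepted;
        }
        n_past += accepted + 1;
        printf("accepted=%d drafted=%d new_n_past=%d\n", accepted, (int) draft.size(), n_past);
    }
    return 0;
}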

@saood06 (Collaborator) commented Jul 25, 2025

Thank you for doing this. I can test/review/assist if you need.

@saood06 (Collaborator) commented Jul 25, 2025

Also, are you aware that this exists: https://github.com/ikawrakow/ik_llama.cpp/blob/main/examples/speculative/speculative.cpp

@g2mt (Contributor, Author) commented Jul 25, 2025

I got the server to compile, but when loading Qwen 2.5 1.5b with the 0.5b version as the draft, I get this error:

ggml_backend_alloc_ctx_tensors_from_buft: all tensors in the context are already allocated
llama_kv_cache_init: failed to allocate buffer for kv cache
llama_new_context_with_model: llama_kv_cache_init() failed for self-attention cache
llama_init_from_gpt_params: error: failed to create context with model 'Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf'
 ERR [              load_model] failed to load draft model | tid="140650859190528" timestamp=1753420591 model="Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf"

GDB says it occurred in this llama_init_from_gpt_params call:

            llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);

I wonder if llama_kv_cache_init is unable to load tensors with the same name. I'll try to fix the code later.

@g2mt (Contributor, Author) commented Jul 25, 2025

Also, are you aware that this exists: https://github.com/ikawrakow/ik_llama.cpp/blob/main/examples/speculative/speculative.cpp

I am aware of the example. I'll check it later.

@saood06 (Collaborator) commented Jul 25, 2025

I am aware of the example. I'll check it later.

Sorry, I forgot my history. The common implementation (introduced here: ggml-org/llama.cpp#10362) was done before the server one: ggml-org/llama.cpp#10455. The common implementation was made to be simpler to understand and work with, which is why it came bundled with https://github.com/ggml-org/llama.cpp/tree/8f419181d1c20d8195148680df15b6f093cb1512/examples/speculative-simple

@g2mt (Contributor, Author) commented Jul 25, 2025

I'm now able to load the draft model. It seems that the kv-cache tensor names were reused for both models. Prefixing them with the model name fixes it.
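
For anyone following along, here is a minimal, self-contained sketch of that idea (the helper name and the exact name format are illustrative assumptions, not the actual patch, which lives in llama_kv_cache_init):

// Illustrative only: per-model prefixes keep KV cache tensor names unique when a
// target model and a draft model are created in the same process.
#include <cstdio>
#include <string>
#include <vector>

static std::string kv_tensor_name(const std::string & prefix, char kind, int layer) {
    char buf[128];
    // e.g. "draft.cache_k_l0" instead of a bare "cache_k_l0"
    snprintf(buf, sizeof(buf), "%s.cache_%c_l%d", prefix.c_str(), kind, layer);
    return buf;
}

int main() {
    std::vector<std::string> names;
    for (const char * prefix : {"target", "draft"}) {
        for (int il = 0; il < 2; ++il) {
            names.push_back(kv_tensor_name(prefix, 'k', il)); // no duplicates across the two models
            names.push_back(kv_tensor_name(prefix, 'v', il));
        }
    }
    for (const auto & n : names) {
        printf("%s\n", n.c_str());
    }
    return 0;
}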

@saood06 (Collaborator) commented Jul 25, 2025

I'm now able to load the draft model. It seems that the kv-cache tensor names were reused for both models. Prefixing them with the model name fixes it.

Nice. Did you get any accepted tokens?

g2mt force-pushed the speculative-port branch from ffb9d71 to 4a41cfd on July 25, 2025 08:36
@g2mt (Contributor, Author) commented Jul 25, 2025

I think I got it working. For some reason ik_llama's slot.id is offset by 1, which tripped me up a bit.

A simple test of repeating a string shows it working:

curl -s http://localhost:9001/v1/chat/completions \
          -H "Content-Type: application/json" \
          -H "Authorization: Bearer no-key" \
          -d '{"model": "test","messages": [{"role": "user","content": "Repeat the following sentence, as is: The quick brown fox jumped over the lazy dog."}]}'
{"choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"The quick brown fox jumped over the lazy dog."}}],"created":1753433480,"model":"test","object":"chat.completion","usage":{"completion_tokens":14,"prompt_tokens":26,"total_tokens":40},"id":"chatcmpl-QK3CBenhWiSBeeuIs6UGs2yXCV5YpqRO","__verbose":{"content":"The quick brown fox jumped over the lazy dog.","generated_text":"The quick brown fox jumped over the lazy dog.",

Server logs do show the speculative decoding results being accepted:

VERB [            update_slots] speculative decoding result | tid="140737350637888" timestamp=1753433480 id_slot=0 accepted=12 total=13 new_n_past=39

It looks like it's working, but I think more testing is needed. If someone else could post more test results that would be great. I'll open the PR up for review now.

g2mt marked this pull request as ready for review on July 25, 2025 09:02
@saood06 (Collaborator) commented Jul 25, 2025

If someone else could post more test results that would be great. I'll open the PR up for review now.

I'll try to do some tests within a day.

@ikawrakow (Owner) left a comment:

Make the suggested flash attention setting change. After that, let's have @ChicoPinto70 confirm that this fixes their issue, then we can merge.

@usrlocalben (Contributor) commented Aug 7, 2025

I grabbed the latest. It still (mostly) works for me.

However, I have observed a problem. Unfortunately it seems to be highly prompt dependent and I haven't been able to work out the exact conditions that produce the badness.

As best I can tell for now, a prompt that enters a cycle, which is then halted, then regenerated, will get the system into some kind of state that gives 1/10 TG (or worse) throughput from thereon until the server is restarted, regardless of subsequent prompts.

I have an ascii-art-generating prompt that may enter a cycle of underscores ______ ______ ______ ______ ______. If I observe this then it's likely that I've reached the bad state, and all prompts from there will give me around 1/15 TG (~1.0 t/s instead of 15 t/s).

Not clearing a buffer? Bug in cache reuse + speculation? Bad sampler state? Overrun that harms the speculation config?

I hate to make such a vague report, but I can't narrow it down to a clear concept. It almost seems like it's important what the repeating tokens are, although I can't imagine that's actually the cause.

Edit: for reference, my config is up above in this thread, triggered with ubergarm's IQ3 and IQ4 quants, although I suspect the quant itself has nothing to do with it.

@g2mt (Contributor, Author) commented Aug 7, 2025

Let me also add the draft acceptance parameters --draft-max and --draft-min, since they don't seem to be present (see the sketch at the end of this comment).

As best I can tell for now, a prompt that enters a cycle, which is then halted, then regenerated, will get the system into some kind of state that gives 1/10 TG (or worse) throughput from thereon until the server is restarted, regardless of subsequent prompts.

I also encounter this from time to time. It happens mostly at the start of prompt generation when feeding the server new messages. It could be because of additional prompt processing that has to be done, but I haven't looked into it.
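
Back to the --draft-max / --draft-min point: here is a toy sketch of how draft-length limits like these are typically applied (the flag names come from above; the gating logic and the p_min threshold are assumptions for illustration, not the actual implementation):

// Toy illustration of draft-length gating (not the real code).
#include <cstdio>
#include <vector>

struct DraftParams {
    int   n_max = 16;    // hard cap on drafted tokens per round (--draft-max)
    int   n_min = 1;     // don't bother speculating below this many tokens (--draft-min)
    float p_min = 0.75f; // assumed: stop drafting once the draft model gets unsure
};

// Hypothetical per-step output of the draft model.
struct DraftStep { int token; float prob; };

static std::vector<int> build_draft(const std::vector<DraftStep> & steps, const DraftParams & p) {
    std::vector<int> draft;
    for (const auto & s : steps) {
        if ((int) draft.size() >= p.n_max) break; // respect the cap
        if (s.prob < p.p_min)             break; // draft model is guessing; stop early
        draft.push_back(s.token);
    }
    if ((int) draft.size() < p.n_min) draft.clear(); // too short to be worth a verify pass
    return draft;
}

int main() {
    DraftParams p;
    std::vector<DraftStep> confident   = {{10, 0.99f}, {11, 0.95f}, {12, 0.90f}, {13, 0.40f}};
    std::vector<DraftStep> unconfident = {{20, 0.30f}};
    printf("confident run drafts %zu tokens\n",   build_draft(confident,   p).size());
    printf("unconfident run drafts %zu tokens\n", build_draft(unconfident, p).size());
    return 0;
}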

@ChicoPinto70 commented Aug 7, 2025

Make the suggested flash attention setting change. After that let's have @ChicoPinto70 confirm that this fixes their issue, then we can merge.

Hi, guys! I reran the test (in gdb), offloading the draft to the GPU (-ngld 64) and adding the -ctkd q8_0 parameter, and the bug vanished. But the speed is still too slow (around 1.1 Tk/s).

And I'm receiving the following messages all the time:

ggml_backend_cuda_graph_compute: disabling CUDA graphs due to mul_mat_id
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to mul_mat_id
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to mul_mat_id
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to mul_mat_id
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to mul_mat_id
[Thread 0x7fba526fa000 (LWP 174681) exited]
[Thread 0x7fb5cfebe000 (LWP 174680) exited]
[Thread 0x7fb5cf6bd000 (LWP 174679) exited]
[Thread 0x7fba48f6c000 (LWP 174678) exited]

@saood06 (Collaborator) commented Aug 8, 2025

I created a separate issue for the GGML_ASSERT, as it is not related to this PR: it can be triggered on a prior commit, and also without using a draft model or speculative decoding at all.

@saood06 (Collaborator) commented Aug 8, 2025

I have an ascii-art-generating prompt that may enter a cycle of underscores ______ ______ ______ ______ ______. If I observe this then it's likely that I've reached the bad state, and all prompts from there will give me around 1/15 TG (~1.0 t/s instead of 15 t/s).

Not clearing a buffer? Bug in cache reuse + speculation? Bad sampler state? Overrun that harms the speculation config?

Another potential cause could be that the draft model is suddenly generating much longer drafts with a near-zero acceptance rate; that would at least explain the dramatic slowdown being seen.
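
To make that concrete, here is a toy model of the draft-then-verify accounting (purely illustrative, not the server code): a long drafted batch that is rejected at the first position still costs a full target pass over the whole batch but yields only one usable token.

// Why long drafts with a low acceptance rate hurt: the target model evaluates the
// whole drafted batch, but only the matching prefix (plus one corrected token) is kept.
#include <cstdio>
#include <vector>

// Hypothetical verify step: count how many drafted tokens match the target's choices.
static int count_accepted(const std::vector<int> & draft, const std::vector<int> & target) {
    int n = 0;
    while (n < (int) draft.size() && draft[n] == target[n]) {
        ++n;
    }
    return n;
}

int main() {
    // A good round: 12 of 13 drafted tokens accepted (like the log further up).
    std::vector<int> draft_good  = {1,2,3,4,5,6,7,8,9,10,11,12,99};
    // A bad round: the very first drafted token already mismatches, so a target
    // pass over 13 positions yields just 1 usable token -- wasted batch work.
    std::vector<int> draft_bad   = {50,2,3,4,5,6,7,8,9,10,11,12,13};
    std::vector<int> target_pick = {1,2,3,4,5,6,7,8,9,10,11,12,13};

    for (const auto & draft : {draft_good, draft_bad}) {
        const int accepted = count_accepted(draft, target_pick);
        printf("drafted %zu tokens, accepted %d, kept %d\n",
               draft.size(), accepted, accepted + 1);
    }
    return 0;
}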

@ikawrakow (Owner) commented:

@ChicoPinto70

But the speed is still too slow (around 1.1 Tk/s).

Is this with the debug or release build? How does speculative compare to no speculation with your setup?

@ChicoPinto70 commented:

@ChicoPinto70

But the speed is still too slow (around 1.1 Tk/s).

Is this with the debug or release build? How does speculative compare to no speculation with your setup?

I ran it with debug (inside gdb) and the output was fine, but the speed was too slow. It ran at 1.1 T/s, whereas without speculative decoding it runs at ~7 T/s.

@ikawrakow (Owner) commented:

I ran it with debug (inside gdb) and the output was fine, but the speed was too slow. It ran at 1.1 T/s, whereas without speculative decoding it runs at ~7 T/s.

But the run without speculative decoding is also with a debug build inside gdb? Or is it a Release build? To compare, you need to have both running in Release.

@ChicoPinto70 commented Aug 8, 2025

I ran it with debug (inside gdb) and the output was fine, but the speed was too slow. It ran at 1.1 T/s, whereas without speculative decoding it runs at ~7 T/s.

But the run without speculative decoding is also with a debug build inside gdb? Or is it a Release build? To compare, you need to have both running in Release.

Ah, I got it. Let me test it again without speculative decoding in debug, and with speculative decoding in release.

@ChicoPinto70 commented Aug 8, 2025

I ran it with debug (inside gdb) and the output was fine, but the speed was too slow. It ran at 1.1 T/s, whereas without speculative decoding it runs at ~7 T/s.

But the run without speculative decoding is also with a debug build inside gdb? Or is it a Release build? To compare, you need to have both running in Release.

ikawrakow, I've just rerun the tests and these are the results:

Debug: 1.1 T/s (with SD) - 1.6 T/s (w/o SD)
Release: 3.5 T/s (with SD) - 7.9 T/s (w/o SD)

I ran a small prompt ("Tell me a joke") with the following command:

CUDA_VISIBLE_DEVICES="1,2,0" ./build/bin/llama-server --alias unsloth/DeepSeek-R1-0528-UD-Q3_K_XL -m /home/chico/.lmstudio/models/unsloth/DeepSeek-R1-0528-GGUF/DeepSeek-R1-0528-UD-Q3_K_XL-00001-of-00007.gguf -ngl 64 -c 16384 -mla 3 -fa -amb 1024 -fmoe -t 32 -ctk q8_0 -ot "blk.[0-6]..*_exps.=CUDA1,blk.(7|8|9|10)..*_exps.=CUDA2,exps=CPU" --parallel 1 --numa distribute -b 4096 -ub 4096 --no-mmap -ts 1,0,0 -ser 7,1 --host 192.168.0.9 --port 1235 -md /home/chico/.lmstudio/models/jukofyork/DeepSeek-R1-DRAFT-0.6B-v2.0-GGUF/DeepSeek-R1-DRAFT-0.6B-128k-Q4_0.gguf -ngld 64 -ctkd q8_0

@ChicoPinto70 commented Aug 8, 2025

Guys, I was afraid my offload strategy was interfering with the speculative decoding implementation. So, I've changed my command to:

CUDA_VISIBLE_DEVICES="1,2,0" ./build/bin/llama-server --alias unsloth/DeepSeek-R1-0528-UD-Q3_K_XL -m /home/chico/.lmstudio/models/unsloth/DeepSeek-R1-0528-GGUF/DeepSeek-R1-0528-UD-Q3_K_XL-00001-of-00007.gguf -ngl 64 -c 32768 -mla 3 -fa -amb 1024 -fmoe -t 32 -ctk q8_0 -ot "blk.(3|4)..*_exps.=CUDA0,blk.(5|6)..*_exps.=CUDA1,blk.(7|8)..*_exps.=CUDA2,exps=CPU" --parallel 1 --numa distribute -b 4096 -ub 4096 --no-mmap -ser 7,1 --host 192.168.0.9 --port 1235 -md /home/chico/Downloads/DeepSeek-R1-DRAFT-0.6B-32k-Q4_0.gguf -ngld 64 -ctkd q8_0

I also reduced the layers offloaded to the GPUs to ensure there is enough memory for the draft model.

Doing that, in release mode, I got 4.8 T/s with SD and 7.6 T/s without it.

P.S. I also changed the draft model to the 32k-ctx one.

@usrlocalben (Contributor) commented:

@ChicoPinto70 apologies if this is known already, but in case it isn't clear: draft/speculative results are highly dependent on the content being generated. Good candidates include token streams that are repetitive, or where the space of possibilities for the next token is very tight. Code generation happens to be a good fit, and repetitive code even better. "Tell me a joke" is unlikely to yield a high hit%.

Also, it's imperative that sampling is deterministic, or it will rarely yield a hit.

from my llama-swap config:

  "K2-IQ4_KS-Speculative":
    cmd: >
      ${server_ik} ${api} ${slow_pcie}
      -fa -mla 2 -fmoe
      -b 4096 -ub 4096
      --n-gpu-layers 99
      -c 32000
      -ot "blk\.(1|2|3|4)\.ffn_up_exps=CUDA0,blk\.(1|2|3|4)\.ffn_gate_exps=CUDA0"
      -ot exps=CPU
      --top-k 1 --samplers "top_k"
      -m /path/to/k2/ubergarm/IQ3_KS/Kimi-K2-Instruct-IQ3_KS-00001-of-00010.gguf
      -md /path/to/k2/jukofyork/draft/Kimi-K2-Instruct-DRAFT-0.6B-32k-Q4_0.gguf -ngld 99
    filters:
      strip_params: "temperature, top_p, top_k"

Here I set up top-k and ensure that clients can't cause e.g. temperature to be used.

short discussion
lengthy discussion

@ChicoPinto70 commented Aug 8, 2025

--top-k 1 --samplers "top_k"

Hi, @usrlocalben! You're right! I should pay attention to the content being generated, but I didn't know how to get that right.
Well, I ran the same test, adding the --top-k 1 --samplers "top_k" parameters and changing the prompt to "write a tic-tac-toe game in python" to force a code-generation output.
I didn't mention it before, but I always use ik_llama-server's own web UI for the tests.

With all that, the TG speed with speculative decoding rose to 5.9 T/s! It's still slower than the plain model, but we're heading in the right direction. Thanks!

@ikawrakow (Owner) commented:

Status?

@saood06 (Collaborator) commented Aug 15, 2025

Status?

You said "let's have @ChicoPinto70 confirm that this fixes their issue, then we can merge.".

They still have lowered performance, but speculative decoding only guarantees preserving accuracy; performance can be better or worse. It is a little odd that it is never on par or better for him in any of the things he tested (code gen with greedy sampling is about as best-case as it gets), but even then I think this is fine.

The result_timings object has a way to display info (I didn't actually look at it when testing myself), but maybe @ChicoPinto70 could (it is viewable in the default bundled web UI). The tools for a user to check whether it is helping or hurting performance are there (although it may take some experimentation and research to figure it out).

I do think the params for the draft model should be more configurable. SmolLM4 is planned to be MLA (someone from the team reported this on Reddit), and overall it seems like it would make for a very good model to train draft models from, but I'm fine with merging this as is and dealing with that when the time comes.

Edit:

Forgot about this:

I have an ascii-art-generating prompt that may enter a cycle of underscores ______ ______ ______ ______ ______. If I observe this then it's likely that I've reached the bad state, and all prompts from there will give me around 1/15 TG (~1.0 t/s instead of 15 t/s).

It might be useful to run in verbose mode (and maybe even with 946fa65 reverted) and see if that can be reproduced, because I do feel like it happens because the draft model is generating a LOT of garbage tokens, which could tank performance (rejecting that many tokens means doing long batches and only using the first token).

@ChicoPinto70 commented Aug 15, 2025

The result_timings object has a way to display info (I didn't actually look at it when testing myself), but maybe @ChicoPinto70 could (it is viewable in the default bundled web UI). The tools for a user to check whether it is helping or hurting performance are there (although it may take some experimentation and research to figure it out).

Hi, guys.

For some days I have been trying to figure out a way to improve the speculative decoding performance of the models I use for work. I even asked GLM 4.5 and Qwen 3 Coder (running on my local machine) for help.

Then, a couple of days ago, I saw ikawrakow was working on a PR (#689) to improve TG in MoE models. I tried it but, as the models I use are all large MoE, I didn't see any great improvement.

But I wondered: what if the PR helps draft models? So, a couple of days ago, I merged this PR with #689 to check it and, unfortunately, I didn't see any improvement. But now I can test speculative decoding with the new GLM 4.5 (I had only tested with DeepSeek before, because Kimi K2 almost doesn't fit in my machine and the Qwen 3 draft doesn't work with ik_llama).

And with GLM 4.5 I saw an improvement in TG speed! Small (from 5.3 to 5.9 Tok/s), but an improvement!

Today I tried to use result_timings, suggested by saood06, to see the acceptance rate of both drafts (DeepSeek and GLM 4.5) and check if that was the culprit, but I couldn't figure out how to use it.

However, I found that the --verbose parameter gives these numbers (in the VERB [ operator()] message).

Both drafts have an acceptance rate of around 75% on a coding test. So I can only imagine the difference between the two models' speeds is caused by the size of DeepSeek. I'm using unsloth's UD-Q3_K_XL because it is the biggest that fits in my machine (I have a dual Xeon E5-2699v3 with 256GB and 3x3090), and the draft overhead must be impacting performance more than any speculative decoding benefit.
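
As a rough sanity check on that reasoning, here is a toy break-even estimate (a back-of-the-envelope model under an independence assumption, not a measurement from this thread):

% t_T = target time per token, t_D = draft time per token, t_V = time of one
% batched verify pass, k = draft length, a = per-token acceptance probability.
% A speculative round costs about k*t_D + t_V and yields on average
\[
E[n] = \frac{1 - a^{k+1}}{1 - a}
\quad\text{tokens, so}\quad
\text{speedup} \approx \frac{E[n]\, t_T}{k\, t_D + t_V}.
\]
% With a = 0.75 and k = 8, E[n] is only about 3.7, so the round wins only if
% k*t_D + t_V stays below ~3.7*t_T. For a huge CPU-offloaded MoE target, the
% batched verify pass t_V can itself cost well above t_T, which is consistent
% with speculative decoding coming out slower here despite 75% acceptance.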

That said, IMHO this PR is OK to merge into main.

@g2mt (Contributor, Author) commented Aug 16, 2025

Once this is merged, I'll probably start working on getting universal assisted decoding in (ggml-org/llama.cpp#12635). It should allow loading draft models whose vocabulary differs from the main model's.

ikawrakow merged commit b837773 into ikawrakow:main on Aug 16, 2025
@whatever1983 commented:

@g2mt:
Universal assisted decoding is horrible; having a different vocab is just terrible engineering. Eagle3 is way better. The best thing to do is to convert baseten-admin/EAGLE3-gpt-oss-120b-bf16 to GGUF and go from there.

@ChicoPinto70 commented Aug 20, 2025

@g2mt, @ikawrakow I discovered why I was only getting a small performance gain with speculative decoding. My system has 3 GPUs. When the draft is loaded, it is spread among the available GPUs. But the draft is only a 0.6~0.75B model, and this partitioning hurts its performance. To fix that, I forced the draft to be loaded on one GPU only. This way, I got a 15% to 20% increase in TG speed on ubergarm's GLM 4.5 with jukofyork's draft model.
