server: implement GLM-style MTP #15225
This is correct - we always alternate between conventional and speculative passes. It's definitely not optimal, but it improves flexibility for regular sampling: it allows changing the speculative parameters and even disabling speculation per request, while the logic stays quite simple. It should be possible to improve this by keeping track of which slots are speculating on each iteration and skipping those slots when adding tokens to the conventional batch. It might be a good idea to implement this separately to avoid huge changes to the logic in a single PR.
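A minimal sketch of that bookkeeping, using hypothetical names rather than the server's actual data structures: a per-slot flag marks slots whose tokens are covered by the speculative pass this iteration, and the conventional batch simply skips them.

```cpp
#include <vector>

// Hypothetical, simplified stand-ins for the server's slot bookkeeping;
// these are not the actual server.cpp types.
struct slot_state {
    int  id;
    bool speculating; // set while this slot's next tokens come from the speculative pass
};

// Collect only the slots that still need a conventional decode this iteration.
static std::vector<int> slots_for_conventional_batch(const std::vector<slot_state> & slots) {
    std::vector<int> ids;
    for (const auto & s : slots) {
        if (s.speculating) {
            continue; // covered by the speculative pass - don't add its token here
        }
        ids.push_back(s.id);
    }
    return ids;
}
```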
Generally we should try to minimize the changes to […]. On first look, I think the path that involves minimal changes is:

- Extracting the MTP logits during […]
- […]

Currently, I am not sure which way is better. The first requires a new API call, while the second might break some existing assumptions (not sure if that's the case yet). In any case, you can avoid this until you get the implementation working with a reasonable speedup. After that, we can discuss further how to best refactor the implementation.
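Purely for illustration, a hypothetical shape for the "new API call" option; `llama_get_mtp_logits_ith` is not an existing llama.h function, just an assumption of what such a call could look like, mirroring the existing `llama_get_logits_ith`:

```cpp
// Hypothetical API sketch - NOT part of llama.h. The idea: expose the logits
// produced by the MTP head during the last llama_decode(), addressed the same
// way as llama_get_logits_ith().
LLAMA_API float * llama_get_mtp_logits_ith(struct llama_context * ctx, int32_t i);

// Hypothetical usage after a decode:
//   llama_decode(ctx, batch);
//   const float * mtp_logits = llama_get_mtp_logits_ith(ctx, batch.n_tokens - 1);
```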
I don't see an issue with adding a new API for this, and it would be easier to use.
Out of curiosity, is the API for this expected to be flexible enough that we could jump off of it to add things like Medusa/Eagle-style (or IBM Accelerator) self-speculative decoding heads? I'm pretty sure they work fairly similarly (depending on the final output embeddings of the current token).

Another note: after some consideration, I think the expected speedup of the MTP module will depend a lot on the hardware the model's running on, particularly because it's an MoE model. While the next-token prediction depends only on the current state, self-speculative decoding means additional forward passes, and those forward passes aren't guaranteed to have the same expert usage patterns. So the speedup should be some function of the tokens predicted and the expert re-use coefficient for the tokens verified. Just noting that if it's implemented and there's not a 2x or 3x increase in T/s, it may not be a skill issue on the part of a contributor, but a consequence of the mathematical nature of the calculation.

For people running franken setups with attention/KV cache on GPU and MoE FFNs on CPU, it's possible that pulling in previously unused experts during the verification sweep results in a weird situation where the parallel verification process is actually memory-bandwidth bound.

Not to discourage the implementation of this; I just wanted to give a heads-up so nobody's dejected if the theoretical speedups can't be hit. There should still be at least some speedup, though.
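To make that intuition concrete with a rough, assumed cost model (notation is mine, not from the thread): if a single-token pass streams $E_1$ distinct experts, the verification pass over the drafted tokens touches $E_v \ge E_1$ distinct experts in total, and expert weight streaming dominates, then

```math
T_{\text{verify}} \approx T_1 \cdot \frac{E_v}{E_1}
```

so part of the benefit of verifying tokens in parallel is eaten by the extra experts that have to be pulled in, which is exactly the memory-bandwidth-bound scenario described above for CPU-offloaded MoE FFNs.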
Thanks all for the suggestions. Will definitely look to refactor into something nicer once correctness can be established.
Yeah, I'd generally recommend that people temper their expectations with this. Especially given that these three models only have one MTP head, the theoretical performance gain is hard-bounded by 2x on the top end, and that's assuming a perfectly efficient implementation and 100% draft acceptance. In the absence of actual data from a working prototype... I'd probably guess that the implementation after this PR will be on the order of a 40% speedup, then up to 80% after completing the follow-up improvement described above.
Optimistically, I hope to have an ugly but working prototype done sometime today.
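For reference, a back-of-the-envelope for the 2x bound mentioned above, assuming a single draft token per pass with acceptance probability $p$ and a verification pass that costs $c \ge 1$ times a plain single-token pass (notation assumed here, not from the thread):

```math
\text{accepted tokens per pass} = 1 + p \le 2,
\qquad
\text{speedup} \approx \frac{1 + p}{c} \le 2
```

The 2x ceiling corresponds to $p = 1$ (every draft accepted) and $c = 1$ (a perfectly efficient implementation).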
I've gotten to the point where I can get the MTP head to output stuff, but managing the KV cache with an external call to a separate MTP graph adds an unbelievable amount of complexity: I think we need to do a forward pass for the MTP layer not just when we're sampling, but for every decode token we run. This goes against the scheduling/batching that we're doing (like we'd probably have to add some form of per-token callback to […]).

Think I'll take the principled approach suggested by @ggerganov above and just create a single augmented graph. But on the plus side, from this previous attempt I'm pretty confident the MTP subgraph itself is correct, so it wasn't a total waste of time. 🤪 I'll commit the old branch in a sec in case it ever winds up being useful, but I kind of doubt it (outside of serving as a reference for constructing the MTP subgraph).
On second thought, building a single augmented graph also doesn't work, because we need the main model's sampled token in the MTP subgraph. We could make some shortcut assumptions, like "greedy sample" in the MTP subgraph, but as soon as we fail to match the actual token sampled by the main model for the first time, the MTP layer's KV cache becomes invalid.

Something along the lines of the original approach might work; management of the MTP subgraph's KV cache could be made easier by using `cparams.embeddings = true` and `LLAMA_POOLING_TYPE_NONE`, decoding an entire batch, then running the entire batch through the MTP head (discarding outputs) to keep its cache up to date.
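A hedged sketch of that idea, assuming a hypothetical helper `mtp_sync_kv_cache` (not an existing function) that pushes a whole batch through the MTP head purely to advance its KV cache:

```cpp
#include "llama.h"

// Hypothetical helper (does not exist in llama.cpp): run a batch through the
// MTP subgraph only to update its KV cache, discarding the outputs.
void mtp_sync_kv_cache(llama_context * ctx, const llama_batch & batch);

// Sketch: keep the MTP layer's KV cache in lock-step with the main model.
// The context is assumed to have been created with cparams.embeddings = true
// and cparams.pooling_type = LLAMA_POOLING_TYPE_NONE, so per-token embeddings
// from the main pass remain accessible to the MTP head.
void decode_and_sync_mtp(llama_context * ctx, llama_batch & batch) {
    llama_decode(ctx, batch);       // main-model pass over the whole batch
    mtp_sync_kv_cache(ctx, batch);  // same tokens through the MTP head, outputs discarded
}
```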
This is very much a draft/proof of concept I'm playing with, just one idea for an MTP implementation. Planning to test on GLM-4.5 because it's the only model out there that we've preserved NextN tensors for.
From what I can tell, […]

So implementation-wise it seems like we want a `mtp_speculative_gen_draft` in speculative.cpp that is vastly simplified, and we branch into it in server.cpp when a slot has MTP (versus `common_speculative_gen_draft`). There's also the question of what to use for `ctx_dft` in this case. It's a bit hacky, but I was thinking we could just have `ctx_dft = ctx` and then have both normal and MTP passes write over the shared `ctx` logits. I think this minimizes required code changes elsewhere (see the sketch below).

This is my first time (1) working with ML stuff outside of Python and (2) attempting to contribute, so patience is appreciated :)
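A sketch of the branching described above; `has_mtp` and `mtp_speculative_gen_draft` are names assumed here for illustration, while `common_speculative_gen_draft` is the existing helper in common/speculative.cpp (its signature and the include paths are reproduced from memory, so treat them as approximate):

```cpp
#include "common.h"      // llama_tokens
#include "speculative.h" // common_speculative, common_speculative_params, common_speculative_gen_draft

// Hypothetical forward declaration - mtp_speculative_gen_draft does not exist yet;
// it would draft a token from the MTP head using the main context's last hidden state.
llama_tokens mtp_speculative_gen_draft(llama_context * ctx,
                                       const common_speculative_params & params,
                                       const llama_tokens & prompt,
                                       llama_token id_last);

// Sketch of the per-slot branch (has_mtp is a hypothetical slot property).
static llama_tokens gen_draft(bool has_mtp,
                              llama_context * ctx,
                              common_speculative * spec,
                              const common_speculative_params & params_spec,
                              const llama_tokens & prompt,
                              llama_token id_last) {
    if (has_mtp) {
        // MTP shares the main model, so the "draft context" is just ctx itself
        // (the ctx_dft = ctx idea above) and no separate draft model is loaded.
        return mtp_speculative_gen_draft(ctx, params_spec, prompt, id_last);
    }
    return common_speculative_gen_draft(spec, params_spec, prompt, id_last);
}
```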