mtp-batch: batch prompt processing #3
Conversation
Hey @F1LM1, just wanted to share some progress I've made on the batching issue. I started by exploring the two options we discussed:
The last two commits should fix the major performance problem. Before, the logic was essentially running the full main model graph twice for MTP updates. Now, it properly reuses the state (embeddings and KV cache) from the main model's pass to feed the MTP pass, which saves a huge amount of time. There is one known bug I still need to work on: the cache seems to get desynchronized on large prompts that are split across multiple ubatches, for example when a 5000-token context is processed with a user-set ubatch size of 1500 tokens. Once the performance and this bug look okay, I will refactor the changes.
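A rough sketch of the flow described above, for readers following along. This is not the exact branch code; it assumes mtp_update_kv_cache takes the same arguments shown later in this thread:

```cpp
// Main pass: runs the full main-model graph once and leaves the embeddings and
// KV cache state resident in ctx.
if (llama_decode(ctx, batch_view) != 0) {
    // handle decode failure
}

// MTP pass: reuses the embeddings/KV state left behind by the main pass instead
// of re-running the main graph, and only appends to the MTP layer's cache.
mtp_update_kv_cache(ctx, batch_view, true);
```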
Hey, lots of good stuff here that gets us much closer to a presentable final product, I think. On the other hand, I haven't had time to fully parse everything, but a couple of things look like bugs to me. The prompt processing order of operations is wrong. We follow the sequence
This one should be easy to fix. A harder one is that KV cache seems to be getting written to the wrong cells. This also gets confounded by the first issue above. Let's say I have a 20-token prompt and I'm using GLM-4.5-Air (so layers 0-45 are the main model, and layer 46 is MTP). Then what happens is something like
Then some combination of the above steps get repeated. When I have time over the coming days I'll take a look at the debugger and see what exactly is happening, but I think the main crux of the issue is this alternating behavior where we write cache data for layers 0-45, then 46, then 0-45, then 46, etc. This eventually degrades the output of both the main model and the MTP head because both see random gaps in their respective KV caches; at least, it seems that way in my limited testing. The output of a simple one-word prompt "test" seems pretty different between the base implementation and this commit. The output from this commit isn't incoherent per se, so maybe it wouldn't set off any red flags that something was wrong, but its response does seem quite a bit shorter. Re: this
I don't know what exact behavior you were seeing, but it could be related to what I'm seeing above.

As for how to fix the second issue, we mostly just need to make sure the MTP update steps that update the layer 46 KV cache get placed in the correct slots in the KV cache, which is to say, get applied to the same slots we've already written for layers 0-45. This doesn't cause a conflict because the KV cache is layer-specific. You can see this in:

```cpp
// store to KV cache
{
    const auto & k_idxs = inp->get_k_idxs();
    const auto & v_idxs = inp->get_v_idxs();

    ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
    ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
```

The way my original implementation does this is quite hacky: it's in the step where I manually construct

```cpp
std::vector<uint32_t> idxs;
idxs.push_back(n_past);

llama_kv_cache_unified::slot_info sinfo = {
    /*.s0   =*/ 0,
    /*.s1   =*/ 0,
    /*.strm =*/ { 0 },
    /*.idxs =*/ { idxs },
};

llama_kv_cache_unified::slot_info_vec_t sinfos;
sinfos.push_back(sinfo);

static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(sinfos);
```

to make sure that the MTP graph always writes over slot n_past.
Hey @F1LM1, thanks for the analysis; it helped me a lot to better understand the problems. I found an issue with the cache, specifically with positions when processing larger prompts. My guess was that something was wrong with the ubatch process, and fortunately your analysis helped me find where the issue was. I've implemented a small fix for the positions, but it will raise an error if you try to process something larger than your batch size. In my free time, I'm currently working on a fix for the MTP process, as it is incorrectly updating the cache positions. One specific problem I'm addressing is that the mtp_update_kv_cache function is being called during draft generation with the wrong hidden state, which degrades the MTP to the point where its predictions become more random. I will continue to work on this fix so we can get the proper quality of answers from the model.
@F1LM1 These were my last two commits to fix the crash that occurred when the prompt size exceeded the batch size, alongside other small fixes. Over the last few days, I have been (and still am) stuck on a problem where the MTP draft's checksum keeps changing. Contrary to my assumptions, this indicates the MTP process is polluting (or "poisoning") the main model's KV cache. The result is that while you cannot see major issues with small prompts (the output is coherent), for large prompts (e.g., >5000 tokens) the output is terrible. I tried many approaches; one was to block the "store to KV cache" logic.

I will probably now try to clean up the code in an attempt to make it easier to debug and find the real issue and its fix. Regarding the last two bugs you spotted, I believe I have fixed them with the changes to the MTP cache warmup and position handling logic. If you find any other problems, please feel free to share. I believe the code is moving in the right direction.
I haven't looked at the most recent commits yet, will fire up the debugger and take a look tomorrow, but this sounds like the second issue I was describing last time; I'm not sure I described it precisely.

The MTP layer and the main model use the same memory context. Roughly, the KV cells are an N x K array, where N is the number of tokens processed so far and K is the number of layers. This means it's hard for the MTP KV cache to directly overwrite the same addresses as the main model's KV cache, since they live in different "columns." However, if you're not careful with the ubatch construction, by default it will assume that you've written to all K columns when you write a new row (a new token), even if you've only written to one (the MTP column). This means you end up with two rows per token (one filled during the main model call and a separate, later row filled only in the MTP column), whereas what you need is a single row per token, where layers 1 through K-1 are set during the main model call and layer K is set when the MTP head is run.

The KV cells are automatically written to in the forward step (if you open the KV cache object, you'll see where they get defined in the graph). I vaguely recall that it was mctx.sinfos that determined which row gets written to, but I don't 100% remember; will have to look at it tomorrow.
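To make the row/column picture concrete, here is a small illustration of the two layouts being described, using the GLM-4.5-Air numbers from earlier in the thread (layers 0-45 main model, layer 46 MTP). This is an editorial sketch of the description above, not output from the code:

```
desired (one cell row per token, after 100 tokens):
  cell 0   -> token 0  : layers 0-45 filled, layer 46 filled
  cell 1   -> token 1  : layers 0-45 filled, layer 46 filled
  ...

what the current ubatch handling produces (two cell rows per token):
  cell 0   -> token 0  : layers 0-45 filled, layer 46 empty
  ...
  cell 100 -> token 0  : layers 0-45 empty,  layer 46 filled
  ...
```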
I've been digging into the KV cache issue, and the problem is definitely related to how the MTP cache is being handled. I've made some progress and have a few key findings to share. Based on your two previous messages, I started with two main hypotheses:
My latest commits were aimed at fixing both. After a lot of debugging, here is what I could figure out.

Finding #1: Positional Integrity

First, I wanted to test your hypothesis about incorrect positions. I added detailed logging to trace every KV write operation for both the main model path and the MTP path, and I found a problem in this line:

```cpp
slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i });
```

The third parameter was the problem, so I replaced the call with:

```cpp
common_batch_add(accepted_batch, ids[i], slot.n_past + i, { slot.id }, false);
```

This seems to have fixed the problem with incorrect positions. The logs now confirm that our position management is solid: for any given chunk of tokens, the main model layers (e.g., 0-45) are written first, and then the MTP layer (e.g., 46) is written for the exact same token positions, with no gaps. A summarized log example for a chunk of tokens confirms this: the positions written for layer 46 match those written for layers 0-45.

Finding #2: MTP Cache Warmup
This is the strangest part. I fixed the warmup logic to correctly populate the MTP cache for the entire prompt, chunk by chunk, using the main model's fresh hidden states for each chunk. The logic now looks like this:

```cpp
// Inside the main batch processing loop, right after the main llama_decode call
bool needs_mtp_warmup = false;
if (slot_batched && slot_batched->has_mtp) {
    if (slot_batched->state == SLOT_STATE_PROCESSING_PROMPT || slot_batched->state == SLOT_STATE_DONE_PROMPT) {
        needs_mtp_warmup = true;
    }
}

if (needs_mtp_warmup) {
    mtp_update_kv_cache(ctx, batch_view, true);
}
```

Logically, this should work: the MTP cache is now fully synchronized with the main model cache. However, when this logic is active, the model's output becomes incoherent and hallucinates on long prompts. But if I comment out this warmup block, the model's output becomes perfectly coherent and high-quality. As expected, the MTP's initial draft accuracy is then low (~45%) because it has no context, but it improves over time as it populates its cache during generation.

Current Situation: I'm now stuck on this paradox: a "correct" and full MTP cache warmup poisons the generation quality, while no warmup fixes the quality at the cost of initial MTP performance. This leads me to believe there's a subtle, fundamental issue with how the MTP head's attention mechanism interacts with a long KV history. My sequential method of populating the cache might be creating a state that the MTP head wasn't trained to handle, even if the positions are correct.

I wanted to share these findings while they're fresh. I'll continue to investigate this behavior, but I'm very interested to hear if this triggers any new ideas on your end.
Hi @SamuelOliveirads, good to see you diving in here! From what you're saying, it looks like you're using the logged token positions to determine which KV cells are being written to. But this is actually misleading: it only tells you which token positions are recorded in the ubatch. Look carefully at the code involved.
I claim that the current setup still does not write the layer 46 data into the same cells as the layer 0-45 data, despite the fact that we declare the data to be about the same positions. In other words, after processing 100 tokens, we end up with two cells per token position rather than one.

How does this arise? It's due to the default way the canonical operations manage ubatch creation and the unified KV cache. In particular, these functions don't care about which token positions the ubatch is supposed to represent; they only know that KV cells should never be overwritten unless they were explicitly removed. So even if you create two batches with the same ubatch.pos, they still get written to different cells. To see this in the code, look at the ubatch creation and unified KV cache slot-assignment logic. This is why my prototype built its own sinfos.

I've verified this behavior using a debugger, and it might be helpful for you to do so as well. Just place a breakpoint near the beginning of the slot-assignment code and look at this->v_cells[0].pos. You'll see that the MTP pass gets fresh cells rather than the ones the main model already used. In other words, there are two KV cells associated with each token position, one for layers 0-45 and one for layer 46. But the main model forward pass just sees the entire KV cache, including entries 0-200, so it sees 100 tokens of useful KV cache and then 100 empty tokens, causing it to produce corrupted output.

You should check to make sure I haven't just made a silly mistake or am looking at the wrong commit or something, but I'm like 90% confident this is the issue. FWIW, I don't think this behavior is completely isolated to the warmup step. It's creating doubled entries during token generation as well, but since the gaps are smaller (alternating between valid KV cache entries and empty ones) the main model quality is affected less.

If it turns out this is indeed the issue, a dirty fix is possible and maybe easy using my set_sinfos function. But to make this robust, you'd have to track which cells contain info for given token positions after the main model ubatch, and make sure that sinfos.idxs is set to those same positions for the MTP ubatch.
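For concreteness, a minimal sketch of that robust fix, reusing only the types already shown in this thread (llama_kv_cache_unified::slot_info_vec_t, set_sinfos). Where exactly the main model's sinfos get captured, and the name last_main_sinfos, are assumptions:

```cpp
// Sketch only: remember which cells the main model's ubatch actually used ...
// (last_main_sinfos is a hypothetical copy captured right after the main model's
// ubatch was assigned its cells)
llama_kv_cache_unified::slot_info_vec_t main_sinfos = last_main_sinfos;

// ... and, before running the MTP graph for the same token positions, force the
// MTP ubatch to write into those exact cells instead of freshly allocated ones.
static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(main_sinfos);
// layer 46 writes now land in the same cells as the layers 0-45 writes
```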
My bad, I indeed didn't notice that I was looking at the logical positions rather than the physical cells.
Yes, I took a look and the problem was indeed there. It was due to my lack of deep knowledge of this part of the codebase; you had already pointed out several times where to look.
Yes, it's not. I oversimplified my findings before. I was noticing that a corruption was occurring, but it's subtle and difficult to see with a small context. By disabling the warmup, I was inadvertently masking the issue by starting with a small context, but the corruption was still happening. It just took more generation steps for the incorrect output to become obvious.
That's the point I want to share. Over the last few days, I implemented a fix based on your suggestion: I now store the cell indices (sinfos) used by the main model's ubatch and reuse them for the MTP pass. In terms of performance, it's the same as your original branch, at least in my configuration. However, after fixing the warmup, I could still spot some subtle randomness in the output, which led me to discover the second issue you predicted: duplicated entries during token generation.
This creates two cache entries for the same logical position, which corrupts the attention state. I'm thinking about how to handle that. My best guess is to make the MTP draft's write to the cache a temporary one and then purge its metadata immediately afterwards using a cache-removal function. There are a couple of other options I considered as well.
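As a point of reference, a minimal sketch of that "temporary write, then purge" idea. It assumes the draft token is written at position slot.n_past for sequence slot.id, and it uses upstream llama.cpp's sequence-removal helper llama_kv_self_seq_rm as the cache-removal function; whether that is the right call in this branch is an assumption:

```cpp
// Sketch: run the MTP draft for one token, then immediately drop whatever the
// draft wrote into the cache so it cannot leave a duplicate cell behind.
const llama_pos draft_pos = slot.n_past;

// ... run the MTP draft decode for the single token at draft_pos ...

// purge the draft's cache entries for that position range (assumed helper;
// the [p0, p1) range covers only the draft token's position)
llama_kv_self_seq_rm(ctx, slot.id, draft_pos, draft_pos + 1);
```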
@F1LM1 I believe this is now ready for another review, as I didn't find any more bugs or performance problems.
My last five commits were mainly focused on refactoring and cleaning up the code, while also fixing small bugs that I found along the way. For now, I'm looking into performance, as this branch currently has similar speed to the original branch.

If I don't find any more bugs, I will start to address the performance problems and apply fixes. This might be done in a new branch, depending on the scope of the required changes.
Sounds good to me, I'll do some testing tomorrow and merge this one. Probably agree that more performance improvements make the most sense in a separate PR.
I haven't had the time to fully dive into the behavior today, but it feels vaguely like there might be an off-by-one somewhere after prompt processing. In particular, it seems like the model isn't outputting a token that it should.
In response to your findings, I ran a small test with about 20 requests and observed a 25% error rate where the model failed to use the token you mentioned. For verification, I also ran another 20 requests on your branch and could not reproduce the problem there. Interestingly, I found that the behavior also changes depending on which arguments are used. I don't want to speculate too much, but my initial suspicion is that this could be related to my changes to the graph. I probably won't have time to look into this more deeply until tomorrow or later, but I would definitely appreciate it if you find anything in the meantime.
I've identified and fixed the bug. The issue was with the MTP cache operations, and the fix prevents that behavior. This also explains what we were seeing in the earlier tests.
I agree, not seeing any obvious bugs in this revision in early testing. Hope to do some more thorough verification this weekend but tentatively think we should be happy with where this is now.
Hello everyone! First let me say that I love both your guys' work on implementing MTP functionality for the GLM series! I just wanted to add my findings from testing this latest branch after your many fixes. I am getting better performance with a separate draft model than when using the MTP layers alone or when using MTP + separate draft model. I am using GLM 4.5 Air at 4-bit quant.

Only Separate Draft Model: 21 tok/sec decoding
Only MTP Layers: 13.5 tok/sec decoding
MTP + Separate Draft Model: 15.6 tok/sec decoding

I am not entirely sure why that is, but I found it interesting enough to share. Once again, thank you all for your continued efforts in making this a reality!
@InfernalDread Thanks for sharing this! Your benchmarks are very insightful. Until now, I hadn't tested it with a separate draft model; I assume you're using the one created by the community on Hugging Face? It's interesting that you were able to combine it with the MTP layers. To give you some context on the current progress of MTP in llama.cpp:
Regarding performance, MTP is currently slower than not using it. This is because it doesn't yet leverage all the optimizations that llama.cpp already offers. Once I have the next PR ready with performance improvements, I would definitely appreciate some benchmarks. If you'd like, I can mention you when we have something ready to test, so you can try it out and share your findings.
@SamuelOliveirads That is a very solid plan, and I am very impressed with how far you guys have gotten in such a short amount of time! I genuinely can't wait until this project reaches its finale! To answer your first question, yes, I am using a draft model made by the community on Hugging Face, specifically this one: jukofyork/GLM-4.5-DRAFT-0.6B-v3.0-GGUF. It works surprisingly well for its size! Also, thank you for the explanation; the fact that the performance I saw was without further optimizations/improvements is a very good sign of what is to come with MTP. Lastly, yes, I would love it if you could mention me when the next stage is ready for testing/benchmarking, as I would be glad to help in any way that I can to confirm the changes. Thank you once again!
Didn't find suspicious behavior, so merging. Let's keep discussing performance bottlenecks in another PR? (I'll be away from my computer for the next week or so but will keep checking in here) |
For now, this is a proof of concept for batching the MTP KV cache update. My approach was to unify all MTP-related graph execution into the main llama_decode function, which already handles batching and ubatching efficiently.

The performance is currently not ideal, and my guess is that it's due to the overhead of building and executing the full, combined graph every time decode is called, especially since we now use it for both KV cache updates and single-token draft generation.

From here, I was thinking of two main options to optimize this:

1. Optimize llama_decode to make the unified graph approach more efficient. This would mean investigating why the graph cache isn't being reused and finding ways to avoid redundant computations when we only need the MTP-specific part of the graph.
2. Add a new, dedicated llama_decode_mtp function. This would be a stripped-down version of decode dedicated to running only the MTP graph. This would likely require creating several new helper functions to be shared between both decode paths to avoid too much code duplication.

I'm more inclined towards the first option for the long run. As more models adopt MTP, it seems better to have a single, unified decode path that can handle it, rather than maintaining two separate ones. Still, I'm not completely sure which path is best, so I wanted to share this draft to open a discussion, while I continue to work on ideas to optimize the unified graph approach.
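For readers skimming the thread, a minimal sketch of the shape option 2 could take. llama_decode_mtp and every helper name below are hypothetical, not existing API:

```cpp
// Hypothetical sketch of a dedicated MTP decode path (option 2). The helpers
// split_into_ubatches, reserve_kv_slots_like_main, and build_and_run_mtp_graph
// are illustrative placeholders for logic shared with the regular decode path.
int32_t llama_decode_mtp(llama_context * ctx, const llama_batch & batch) {
    // reuse the same ubatch splitting that the regular decode path performs
    for (auto & ubatch : split_into_ubatches(ctx, batch)) {
        // make the MTP layer write into the cells already used by the main pass
        reserve_kv_slots_like_main(ctx, ubatch);

        // build and run only the MTP sub-graph, fed by the stored hidden states
        build_and_run_mtp_graph(ctx, ubatch);
    }
    return 0;
}
```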