
Commit 9fab53e

Fix the MTP KV cache update step in cases where the prompt size exceeds n_batch and n_ubatch.
1 parent 98bc0c6 · commit 9fab53e

File tree: 3 files changed (+27, −8 lines)


common/speculative.cpp (10 additions, 3 deletions)

```diff
@@ -423,11 +423,18 @@ llama_token mtp_speculative_gen_draft(
 }
 
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
     mtp_kv_update_data token;
-    for (int i = 0; i < tokens.size(); ++i) {
+
+    if (n_tokens < 0) {
+        n_tokens = tokens.size();
+    }
+
+    for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
         token = tokens[i];
-        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
+        //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
+
+        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
     }
 
     tokens.clear();
```
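A note on the bounds check: since `n_tokens` is a `size_t`, the `n_tokens < 0` branch can never be taken; the defaulted `-1` works only because it wraps to `SIZE_MAX`, at which point `std::min(tokens.size(), n_tokens)` clamps back to the queue size. A minimal standalone sketch of the same clamping with an explicit signed sentinel (the helper name and `main` driver are illustrative, not from the commit):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch: same loop bound as mtp_update_kv_cache, but with a signed sentinel,
// so "update everything queued" is an explicit branch rather than a size_t
// wraparound to SIZE_MAX.
static size_t clamp_update_count(size_t queued, int64_t n_tokens) {
    if (n_tokens < 0) {
        return queued;                            // sentinel: take the whole queue
    }
    return std::min(queued, (size_t) n_tokens);   // otherwise clamp to the window
}

int main() {
    std::vector<int> queue(7);                               // stand-in for the update batch
    std::printf("%zu\n", clamp_update_count(queue.size(), -1)); // 7: default, flush all
    std::printf("%zu\n", clamp_update_count(queue.size(), 4));  // 4: clamp to ubatch size
}
```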

common/speculative.h (1 addition, 1 deletion)

```diff
@@ -49,4 +49,4 @@ llama_tokens common_speculative_gen_draft(
     const llama_tokens & prompt,
     llama_token id_last);
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
```
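Both added parameters are defaulted, so existing call sites keep compiling with the old semantics (start at offset 0, flush the whole queue). A usage sketch; the `slot`, `i`, and `n_tokens` names mirror the server loop in `tools/server/server.cpp` below:

```cpp
// Old-style call: offset 0, flush the entire queued update batch.
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);

// Windowed call from the batch loop: only tokens whose logits belong to the
// current ubatch [i, i + n_tokens) are replayed through the MTP head.
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
```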

tools/server/server.cpp (16 additions, 4 deletions)

```diff
@@ -1405,9 +1405,14 @@ struct server_slot {
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
     bool can_split() const {
+        //fprintf(stderr, "need_embd() %d\n", need_embd());
+        //fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr);
+        //fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+
         return
             !need_embd() ||
-            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch?
     }
 
     bool can_batch_with(server_slot & other_slot) const {
```
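Since both branches of the new alternation repeat the `llama_get_memory(ctx)` test, the condition could be factored over the pooling type. A behavior-preserving sketch of the same member function (it assumes the surrounding `server_slot` definition and carries over the commit's own uncertainty about `LLAMA_POOLING_TYPE_NONE`; this is not the commit's code):

```cpp
bool can_split() const {
    if (!need_embd()) {
        return true; // not an embedding request: splitting is always allowed
    }
    const enum llama_pooling_type pt = llama_pooling_type(ctx);
    // With a memory module present, LAST pooling needs only the final token,
    // and NONE appears to keep embeddings for the whole batch, so both split.
    return llama_get_memory(ctx) != nullptr &&
           (pt == LLAMA_POOLING_TYPE_LAST || pt == LLAMA_POOLING_TYPE_NONE);
}
```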
```diff
@@ -3508,6 +3513,13 @@ struct server_context {
                 continue; // continue loop of n_batch
             }
 
+            for (auto & slot : slots) {
+                // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
+                if (slot.has_mtp) {
+                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
+                }
+            }
+
             // move the head of the batch forward with the number of tokens we just processed
             i_next = i + n_tokens;
 
```
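The update now runs once per decoded ubatch, scoped to the window `[i, i + n_tokens)`. That windowing appears to be the core of the fix: when a prompt spans several ubatches, a queued token's logits exist only while its ubatch is current, and its batch-global `tok_idx` must be rebased by `batch_start` into ubatch-local coordinates. A standalone sketch of that indexing (sizes and names are illustrative):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const size_t n_prompt = 10; // tokens queued for the MTP KV cache update
    const size_t n_ubatch = 4;  // tokens decoded per ubatch

    std::vector<size_t> tok_idx(n_prompt);
    for (size_t j = 0; j < n_prompt; ++j) tok_idx[j] = j; // batch-global indices

    // Walk the prompt in ubatch-sized windows, as the server batch loop does.
    for (size_t i = 0; i < n_prompt; i += n_ubatch) {
        const size_t n_tokens = std::min(n_ubatch, n_prompt - i);
        // Only tokens in [i, i + n_tokens) have logits in the current ubatch;
        // their position within those logits is tok_idx - batch_start.
        for (size_t j = i; j < i + n_tokens; ++j) {
            std::printf("token %zu -> ubatch-local index %zu\n", tok_idx[j], tok_idx[j] - i);
        }
    }
}
```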

```diff
@@ -3552,9 +3564,9 @@
             common_sampler_accept(slot.smpl, id, true);
 
             // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
-            if (slot.has_mtp) {
-                mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
-            }
+            //if (slot.has_mtp) {
+            //    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
+            //}
 
             slot.n_decoded += 1;
 
```