Skip to content

Commit 6f74ba3

Browse files
mtp-batch (fix): prevent mtp draft from polluting the cache
1 parent 5e1d719 commit 6f74ba3

File tree

5 files changed

+38
-5
lines changed

5 files changed

+38
-5
lines changed

common/speculative.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ llama_token mtp_speculative_gen_draft(
374374
return -1;
375375
}
376376
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
377+
const llama_pos draft_pos = n_past;
378+
const llama_seq_id draft_seq_id = 0;
377379
common_batch_add(mtp_batch, id_last, n_past, {0}, true);
378380

379381
mtp_batch.update_mtp_kv = false;
@@ -387,6 +389,8 @@ llama_token mtp_speculative_gen_draft(
387389
llama_decode(ctx, mtp_batch);
388390
llama_batch_free(mtp_batch);
389391

392+
llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);
393+
390394
const llama_model * model = llama_get_model(ctx);
391395
const llama_vocab * vocab = llama_model_get_vocab(model);
392396
const int n_vocab = llama_n_vocab(vocab);

include/llama.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1460,9 +1460,13 @@ extern "C" {
14601460
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
14611461

14621462
LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
1463-
1463+
1464+
LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
1465+
14641466
LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
14651467

1468+
LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
1469+
14661470
#ifdef __cplusplus
14671471
}
14681472
#endif

src/llama-context.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3105,6 +3105,20 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float
31053105
ctx->draft_input_hidden_state = hidden_state;
31063106
}
31073107

3108+
bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) {
3109+
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
3110+
const auto & last_sinfo = kvd->last_main_model_sinfos;
3111+
3112+
if (last_sinfo.empty()) {
3113+
LLAMA_LOG_ERROR("%s: The main call sinfo is not available for warmup.\n", __func__);
3114+
return false;
3115+
}
3116+
3117+
kvd->forced_sinfos = &last_sinfo;
3118+
return true;
3119+
}
3120+
3121+
31083122
bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) {
31093123
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
31103124
const auto & last_sinfo = kvd->last_main_model_sinfos;
@@ -3126,4 +3140,14 @@ bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_acc
31263140
void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) {
31273141
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
31283142
kvd->forced_sinfos = nullptr;
3129-
}
3143+
}
3144+
3145+
void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
3146+
if (memory) {
3147+
static_cast<llama_kv_cache_unified *>(memory.get())->seq_rm(seq_id, p0, p1);
3148+
}
3149+
}
3150+
3151+
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
3152+
ctx->kv_cache_seq_rm(seq_id, p0, p1);
3153+
}

src/llama-context.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ struct llama_context {
100100
int32_t il_start,
101101
int32_t il_end);
102102

103+
void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1);
104+
103105
// process a single ubatch with a specific graph type
104106
// if memory_context is provided, it will be applied first to the context's memory
105107
// ret contains the status of the graph computation

tools/server/server.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3520,11 +3520,10 @@ struct server_context {
35203520
needs_mtp_warmup = true;
35213521
}
35223522
}
3523-
3523+
35243524
if (needs_mtp_warmup) {
3525-
if (llama_mtp_prepare_sinfo_for_update(ctx, batch_view.n_tokens)) {
3525+
if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
35263526
mtp_update_kv_cache(ctx, batch_view, true);
3527-
35283527
llama_mtp_cancel_sinfo_update(ctx);
35293528
} else {
35303529
LOG_ERR("%s: Failed to prepare the MTP symphony for warmup.", __func__);

0 commit comments

Comments
 (0)