93 changes: 40 additions & 53 deletions common/speculative.cpp
@@ -363,46 +363,60 @@ llama_tokens common_speculative_gen_draft(
 }


-llama_token mtp_speculative_gen_draft(
+llama_token mtp_update_and_draft(
     struct common_sampler* smpl,
     struct llama_context* ctx,
-    llama_token id_last,
-    int32_t n_past,
-    int32_t last_tok_idx) {
-
-    if (!smpl) {
-        return -1;
-    }
-    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
-    const llama_pos draft_pos = n_past;
-    const llama_seq_id draft_seq_id = 0;
-    common_batch_add(mtp_batch, id_last, n_past, {0}, true);
-
-    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
-
-    // Perform the MTP draft generation decode. This writes the MTP layer's
-    // KV state for the draft token into the cache.
-    llama_decode(ctx, mtp_batch);
-    llama_batch_free(mtp_batch);
-
-    // CRITICAL: Purge the metadata for the draft token we just wrote.
-    // This makes the physical cell available again for the main model's validation pass,
-    // preventing a cache state corruption where two cells map to the same logical position.
-    llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);
-
+    const llama_tokens & ids_accepted,
+    int32_t n_past_base,
+    llama_token id_base
+) {
+    if (!smpl) { return -1; }
+
+    const int n_update = ids_accepted.size();
+    const int n_total = n_update + 1;
+
+    const int32_t n_past_update_start = n_past_base - n_update;
+
+    llama_batch batch = llama_batch_init(n_update + 1, 0, 1);
+    for (int i = 0; i < n_update; ++i) {
+        common_batch_add(batch, ids_accepted[i], n_past_update_start + i, { 0 }, true);
+    }
+    common_batch_add(batch, id_base, n_past_base, { 0 }, true);
+    batch.mtp_params.op_type = MTP_OP_UPDATE_AND_DRAFT;
+
+    if (!llama_mtp_prepare_sinfo_for_update_and_draft(ctx, n_update)) {
+        LOG_ERR("[MTP-FLOW] Failed to prepare hybrid sinfo. Aborting.\n");
+        llama_batch_free(batch);
+        return -1;
+    }
+
+    llama_decode(ctx, batch);
+
+    llama_mtp_cancel_sinfo_update(ctx);
+
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
     const int n_vocab = llama_n_vocab(vocab);
     llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
     cur_p->size = n_vocab;
     for (int i = 0; i < n_vocab; ++i) {
         cur_p->data[i].id = i;
-        cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0.
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, n_total - 1)[i];
     }
     cur_p->sorted = false;
     common_sampler_apply_chain(smpl, cur_p);

-    return cur_p->data[0].id;
+    const llama_token draft_id = cur_p->data[0].id;
+    llama_batch_free(batch);
+
+    // CRITICAL: Purge the metadata for the draft token we just wrote.
+    // This makes the physical cell available again for the main model's validation pass,
+    // preventing a cache state corruption where two cells map to the same logical position.
+    const llama_pos draft_pos = n_past_base;
+    LOG_INF("tokens being deleted: %d and %d\n", draft_pos, draft_pos + 1);
+    llama_kv_cache_seq_rm(ctx, 0, draft_pos, draft_pos + 1);
+
+    return draft_id;
 }


@@ -423,35 +437,8 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
     for (int i = 0; i < mtp_batch.n_tokens; ++i) {
         mtp_batch.logits[i] = true;
     }
+    const int64_t t_start_us = ggml_time_us();
     llama_decode(ctx, mtp_batch);
-}
-
-void mtp_accept_tokens(
-    struct llama_context * ctx,
-    const std::vector<llama_token> & ids,
-    int32_t n_past_base,
-    llama_seq_id seq_id
-) {
-    if (ids.empty()) {
-        return;
-    }
-
-    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
-    // This sets up the context for a "forced sinfo" decode.
-    if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
-        return;
-    }
-
-    // Build a new batch containing only the accepted tokens.
-    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
-    for (size_t i = 0; i < ids.size(); ++i) {
-        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
-    }
-
-    mtp_update_kv_cache(ctx, accepted_batch, false);
-
-    // Clean up the forced state to not affect subsequent, normal decode calls.
-    llama_mtp_cancel_sinfo_update(ctx);
-
-    llama_batch_free(accepted_batch);
-}
+    const int64_t t_end_us = ggml_time_us();
+    LOG_INF("[PERF-MTP] mtp_update_kv_cache internal decode (op=%d): %.2f ms\n", (int)mtp_batch.mtp_params.op_type, (t_end_us - t_start_us) / 1000.0);
+}
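
Note: the snippet below is an illustrative sketch and is not part of the diff. It shows how a caller in the server's speculative loop might drive the new combined entry point after the main model has validated a draft batch; the wrapper name example_next_draft and the surrounding variable names are assumptions, and only mtp_update_and_draft itself comes from this change.

    // Illustrative only (not in this PR): assumed glue in the speculative loop.
    // "ids_accepted" are the tokens the validation pass kept, "id_last" is the
    // last accepted token, and "n_past" is its position in the KV cache.
    static llama_token example_next_draft(
            struct common_sampler * smpl,
            struct llama_context  * ctx,
            const llama_tokens    & ids_accepted,
            llama_token             id_last,
            int32_t                 n_past) {
        // A single call now both replays the accepted tokens through the MTP head
        // (updating its KV cells at positions n_past - ids_accepted.size() .. n_past - 1)
        // and samples the next draft token conditioned on id_last at position n_past.
        return mtp_update_and_draft(smpl, ctx, ids_accepted, n_past, id_last);
    }
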
18 changes: 6 additions & 12 deletions common/speculative.h
@@ -35,12 +35,13 @@ void common_speculative_add_replacement_tgt_dft(


 // sample up to n_draft tokens and add them to the batch using the draft model
-llama_token mtp_speculative_gen_draft(
+llama_token mtp_update_and_draft(
     struct common_sampler* smpl,
     struct llama_context* ctx,
-    llama_token id_last,
-    int32_t n_past,
-    int32_t last_tok_idx);
+    const llama_tokens & ids_accepted,
+    int32_t n_past_base,
+    llama_token id_base
+);

 // sample up to n_draft tokens and add them to the batch using the draft model
 llama_tokens common_speculative_gen_draft(
@@ -49,11 +50,4 @@ llama_tokens common_speculative_gen_draft(
     const llama_tokens & prompt,
     llama_token id_last);

-void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
-
-void mtp_accept_tokens(
-    struct llama_context * ctx,
-    const std::vector<llama_token> & ids,
-    int32_t n_past_base,
-    llama_seq_id seq_id
-);
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
8 changes: 4 additions & 4 deletions include/llama.h
@@ -225,7 +225,7 @@ extern "C" {
     MTP_OP_NONE,
     MTP_OP_WARMUP,
     MTP_OP_UPDATE_ACCEPTED,
-    MTP_OP_DRAFT_GEN,
+    MTP_OP_UPDATE_AND_DRAFT
 } llama_mtp_op_type;

 typedef struct llama_mtp_params {
@@ -1471,15 +1471,15 @@ extern "C" {
 //

 LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);

 /**
  * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
  * This is used after speculative validation when only a subset of draft tokens are accepted.
  * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
  * @return true on success.
  */
-LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
+LLAMA_API bool llama_mtp_prepare_sinfo_for_update_and_draft(struct llama_context * ctx, size_t n_update);

 /**
  * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
  * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.