79 changes: 58 additions & 21 deletions common/speculative.cpp
@@ -373,48 +373,85 @@ llama_token mtp_speculative_gen_draft(
     if (!smpl) {
         return -1;
     }
-    llama_batch batch = llama_batch_init(1, 0, 1);
-    common_batch_add(batch, id_last, n_past, {0}, true);
+    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
+    const llama_pos draft_pos = n_past;
+    const llama_seq_id draft_seq_id = 0;
+    common_batch_add(mtp_batch, id_last, n_past, {0}, true);
+    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
 
-    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+    // Perform the MTP draft generation decode. This writes the MTP layer's
+    // KV state for the draft token into the cache.
+    llama_decode(ctx, mtp_batch);
+    llama_batch_free(mtp_batch);
+
+    // CRITICAL: Purge the metadata for the draft token we just wrote.
+    // This makes the physical cell available again for the main model's validation pass,
+    // preventing a cache state corruption where two cells map to the same logical position.
+    llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);
 
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
     const int n_vocab = llama_n_vocab(vocab);
 
     llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
 
     cur_p->size = n_vocab;
     for (int i = 0; i < n_vocab; ++i) {
         cur_p->data[i].id = i;
-        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0.
     }
     cur_p->sorted = false;
 
     common_sampler_apply_chain(smpl, cur_p);
 
-    const llama_token id = cur_p->data[0].id;
-
-    llama_batch_free(batch);
-
-    return id;
+    return cur_p->data[0].id;
 }
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
-    mtp_kv_update_data token;
-
-    if (n_tokens < 0) {
-        n_tokens = tokens.size();
-    }
-
-    for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
-        token = tokens[i];
-        //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
-
-        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
-    }
-
-    tokens.clear();
-}
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
+    if (batch.n_tokens == 0) {
+        return;
+    }
+
+    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
+
+    llama_batch mtp_batch = batch;
+    if (is_prompt_warmup) {
+        mtp_batch.mtp_params.op_type = MTP_OP_WARMUP;
+    } else {
+        mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
+    }
+
+    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
+        mtp_batch.logits[i] = true;
+    }
+    llama_decode(ctx, mtp_batch);
+}
+
+void mtp_accept_tokens(
+    struct llama_context * ctx,
+    const std::vector<llama_token> & ids,
+    int32_t n_past_base,
+    llama_seq_id seq_id
+) {
+    if (ids.empty()) {
+        return;
+    }
+
+    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
+    // This sets up the context for a "forced sinfo" decode.
+    if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
+        return;
+    }
+
+    // Build a new batch containing only the accepted tokens.
+    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
+    for (size_t i = 0; i < ids.size(); ++i) {
+        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
+    }
+
+    mtp_update_kv_cache(ctx, accepted_batch, false);
+
+    // Clean up the forced state to not affect subsequent, normal decode calls.
+    llama_mtp_cancel_sinfo_update(ctx);
+
+    llama_batch_free(accepted_batch);
+}
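For orientation, the sketch below shows roughly how a caller might wire the two new helpers together. `speculative_step`, the validation placeholder, and the `n_past + 1` position bookkeeping are assumptions, not code from this PR; only `mtp_speculative_gen_draft` and `mtp_accept_tokens` come from the diff above.

// Hypothetical caller-side flow: draft one token with the MTP head, have the
// main model validate it, then bring the MTP KV cache up to date with whatever
// was accepted.
static void speculative_step(common_sampler * smpl, llama_context * ctx,
                             llama_token id_last, int32_t n_past, llama_seq_id seq_id) {
    // 1. Draft a single token (last_tok_idx = 0 for a single-token batch).
    const llama_token draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, 0);

    // 2. Validate the draft with a normal llama_decode() of the main model and
    //    collect the accepted ids -- omitted here; assume the draft was accepted.
    std::vector<llama_token> accepted = { draft };

    // 3. Re-decode only the accepted tokens through the MTP layer so its KV cache
    //    stays aligned with the main model's (positions assumed to continue after id_last).
    mtp_accept_tokens(ctx, accepted, n_past + 1, seq_id);
}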
9 changes: 8 additions & 1 deletion common/speculative.h
@@ -49,4 +49,11 @@ llama_tokens common_speculative_gen_draft(
         const llama_tokens & prompt,
         llama_token id_last);
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
+
+void mtp_accept_tokens(
+    struct llama_context * ctx,
+    const std::vector<llama_token> & ids,
+    int32_t n_past_base,
+    llama_seq_id seq_id
+);
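The `is_prompt_warmup` flag maps onto the two sinfo-preparation entry points declared in include/llama.h below. A minimal warmup call site might look like the following sketch; `prompt_batch` is a hypothetical batch that the main model has just decoded, and pairing the warmup with `llama_mtp_cancel_sinfo_update` is an assumption based on the header comment that a cancel must follow any decode that used a prepared sinfo.

// Hypothetical prompt-warmup call site, right after the main model has decoded
// the prompt contained in prompt_batch.
if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
    // Re-decode the same tokens through the MTP layer (op_type = MTP_OP_WARMUP).
    mtp_update_kv_cache(ctx, prompt_batch, /*is_prompt_warmup=*/true);

    // Drop the forced sinfo so subsequent normal decodes are unaffected.
    llama_mtp_cancel_sinfo_update(ctx);
}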
45 changes: 43 additions & 2 deletions include/llama.h
@@ -221,6 +221,17 @@ extern "C" {
     //  - if not: only the last token is output
     //  )
     //
+    typedef enum {
+        MTP_OP_NONE,
+        MTP_OP_WARMUP,
+        MTP_OP_UPDATE_ACCEPTED,
+        MTP_OP_DRAFT_GEN,
+    } llama_mtp_op_type;
+
+    typedef struct llama_mtp_params {
+        llama_mtp_op_type op_type;
+    } llama_mtp_params;
+
     typedef struct llama_batch {
         int32_t n_tokens;
 
@@ -230,6 +241,7 @@ extern "C" {
         int32_t * n_seq_id;
         llama_seq_id ** seq_id;
         int8_t * logits; // TODO: rename this to "output"
+        llama_mtp_params mtp_params;
     } llama_batch;
 
     enum llama_model_kv_override_type {
@@ -1454,8 +1466,37 @@ extern "C" {
            ggml_opt_epoch_callback callback_train,
            ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
+    //
+    // MTP
+    //
+
+    LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
+
+    /**
+     * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
+     * This is used after speculative validation when only a subset of draft tokens are accepted.
+     * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
+     * @return true on success.
+     */
+    LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
+
+    /**
+     * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
+     * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
+     * @return true on success.
+     */
+    LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
+
+    /**
+     * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
+     */
+    LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
+
+    /**
+     * @brief Removes KV cache metadata for a specified sequence and token range.
+     * This makes the physical cells logically available again without deleting the tensor data.
+     */
+    LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
 
 #ifdef __cplusplus
 }
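Putting the new C API together, one MTP draft step would plausibly follow the call order below. This mirrors the common/speculative.cpp changes above rather than documenting a fixed contract; the source of `hidden_state` is not shown in this diff, and `common_batch_add` is a helper from common/, not part of this header.

// Hypothetical draft step using only the public API declared above.
llama_set_draft_input_hidden_state(ctx, hidden_state);  // main model's hidden state feeds the MTP head

llama_batch b = llama_batch_init(1, 0, 1);
common_batch_add(b, id_last, n_past, { 0 }, true);      // single draft token at position n_past
b.mtp_params.op_type = MTP_OP_DRAFT_GEN;

llama_decode(ctx, b);                                   // writes the draft token's MTP KV state
llama_batch_free(b);

// Make the logical cell available again before the main model's validation pass.
llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, n_past, n_past + 1);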
15 changes: 8 additions & 7 deletions src/llama-batch.cpp
@@ -834,13 +834,14 @@ struct llama_batch llama_batch_get_one(
 
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
-        /*n_tokens =*/ 0,
-        /*tokens =*/ nullptr,
-        /*embd =*/ nullptr,
-        /*pos =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id =*/ nullptr,
-        /*logits =*/ nullptr,
+        /*n_tokens =*/ 0,
+        /*tokens =*/ nullptr,
+        /*embd =*/ nullptr,
+        /*pos =*/ nullptr,
+        /*n_seq_id =*/ nullptr,
+        /*seq_id =*/ nullptr,
+        /*logits =*/ nullptr,
+        /*.mtp_params =*/ { MTP_OP_NONE },
     };
 
     if (embd) {
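Because `llama_batch_init` now initializes `mtp_params` to `MTP_OP_NONE`, existing callers keep their current behavior and MTP processing stays opt-in. A small sketch of the intended usage pattern; the decode behavior described in the comments is an assumption drawn from the enum added to include/llama.h:

// A freshly initialized batch carries no MTP request ...
llama_batch b = llama_batch_init(32, 0, 1);
// b.mtp_params.op_type == MTP_OP_NONE -> llama_decode() behaves as before.

// ... and an MTP pass has to be requested explicitly, e.g. for the prompt warmup:
b.mtp_params.op_type = MTP_OP_WARMUP;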