Commit c2d7c76

Merge pull request #3 from SamuelOliveirads/glm4-mtp-batch
mtp-batch: batch prompt processing
2 parents c6237c7 + cae85fe commit c2d7c76

File tree

13 files changed: +685 -470 lines changed


common/speculative.cpp

Lines changed: 58 additions & 21 deletions
@@ -373,48 +373,85 @@ llama_token mtp_speculative_gen_draft(
     if (!smpl) {
         return -1;
     }
+    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
+    const llama_pos draft_pos = n_past;
+    const llama_seq_id draft_seq_id = 0;
+    common_batch_add(mtp_batch, id_last, n_past, {0}, true);

-    llama_batch batch = llama_batch_init(1, 0, 1);
-    common_batch_add(batch, id_last, n_past, {0}, true);
+    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;

-    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+    // Perform the MTP draft generation decode. This writes the MTP layer's
+    // KV state for the draft token into the cache.
+    llama_decode(ctx, mtp_batch);
+    llama_batch_free(mtp_batch);
+
+    // CRITICAL: Purge the metadata for the draft token we just wrote.
+    // This makes the physical cell available again for the main model's validation pass,
+    // preventing a cache state corruption where two cells map to the same logical position.
+    llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);

     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
     const int n_vocab = llama_n_vocab(vocab);
-
     llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
-
     cur_p->size = n_vocab;
     for (int i = 0; i < n_vocab; ++i) {
         cur_p->data[i].id = i;
-        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0.
     }
     cur_p->sorted = false;
-
     common_sampler_apply_chain(smpl, cur_p);
+
+    return cur_p->data[0].id;
+}

-    const llama_token id = cur_p->data[0].id;

-    llama_batch_free(batch);
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
+    if (batch.n_tokens == 0) {
+        return;
+    }

-    return id;
-}
+    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);

+    llama_batch mtp_batch = batch;
+    if (is_prompt_warmup) {
+        mtp_batch.mtp_params.op_type = MTP_OP_WARMUP;
+    } else {
+        mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
+    }

-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
-    mtp_kv_update_data token;
+    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
+        mtp_batch.logits[i] = true;
+    }
+    llama_decode(ctx, mtp_batch);
+}

-    if (n_tokens < 0) {
-        n_tokens = tokens.size();
+void mtp_accept_tokens(
+    struct llama_context * ctx,
+    const std::vector<llama_token> & ids,
+    int32_t n_past_base,
+    llama_seq_id seq_id
+) {
+    if (ids.empty()) {
+        return;
     }

-    for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
-        token = tokens[i];
-        //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
+    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
+    // This sets up the context for a "forced sinfo" decode.
+    if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
+        return;
+    }

-        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
+    // Build a new batch containing only the accepted tokens.
+    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
+    for (size_t i = 0; i < ids.size(); ++i) {
+        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
     }

-    tokens.clear();
-}
+    mtp_update_kv_cache(ctx, accepted_batch, false);
+
+    // Clean up the forced state to not affect subsequent, normal decode calls.
+    llama_mtp_cancel_sinfo_update(ctx);
+
+    llama_batch_free(accepted_batch);
+}
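
The two helpers above split the MTP work into a draft step (mtp_speculative_gen_draft) and an accept step (mtp_accept_tokens). The sketch below shows one way a caller could chain them; it is illustrative only: the wrapper function, the validation step, and the local names (smpl, id_last, n_past) are assumptions, while the mtp_* calls and their argument order come from this diff.

// Hypothetical caller-side sketch, not part of this commit.
#include <vector>
#include "common.h"
#include "sampling.h"
#include "speculative.h"

static void mtp_step_sketch(llama_context * ctx, common_sampler * smpl,
                            llama_token id_last, int32_t n_past) {
    // Draft one token with the MTP head: internally this decodes a single-token
    // batch tagged MTP_OP_DRAFT_GEN and then purges the draft's KV metadata.
    const llama_token draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, 0);

    // Validate the draft with the main model (omitted here); assume it was accepted.
    std::vector<llama_token> accepted = { draft };

    // Replay only the accepted tokens through the MTP layer so its KV cache
    // catches up with the main model's.
    mtp_accept_tokens(ctx, accepted, n_past, /*seq_id=*/0);
}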

common/speculative.h

Lines changed: 8 additions & 1 deletion
@@ -49,4 +49,11 @@ llama_tokens common_speculative_gen_draft(
         const llama_tokens & prompt,
         llama_token id_last);

-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
+
+void mtp_accept_tokens(
+    struct llama_context * ctx,
+    const std::vector<llama_token> & ids,
+    int32_t n_past_base,
+    llama_seq_id seq_id
+);
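
With the new signature, the prompt-warmup path is expected to feed mtp_update_kv_cache the same batch the main model just decoded. A minimal sketch of such a call site, assuming a prompt_batch variable and using the sinfo helpers declared in include/llama.h below (the wrapper function is hypothetical):

// Hypothetical warmup call site, not part of this commit.
static void mtp_prompt_warmup_sketch(struct llama_context * ctx, const llama_batch & prompt_batch) {
    // Reuse the sinfo of the main model's prompt decode so both caches stay aligned.
    if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
        // Replays the prompt through the MTP layer (tagged MTP_OP_WARMUP internally).
        mtp_update_kv_cache(ctx, prompt_batch, /*is_prompt_warmup=*/true);
        // Clear the forced sinfo so subsequent normal decodes are unaffected.
        llama_mtp_cancel_sinfo_update(ctx);
    }
}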

include/llama.h

Lines changed: 43 additions & 2 deletions
@@ -221,6 +221,17 @@ extern "C" {
     //   - if not: only the last token is output
     // )
     //
+    typedef enum {
+        MTP_OP_NONE,
+        MTP_OP_WARMUP,
+        MTP_OP_UPDATE_ACCEPTED,
+        MTP_OP_DRAFT_GEN,
+    } llama_mtp_op_type;
+
+    typedef struct llama_mtp_params {
+        llama_mtp_op_type op_type;
+    } llama_mtp_params;
+
     typedef struct llama_batch {
         int32_t n_tokens;

@@ -230,6 +241,7 @@ extern "C" {
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       *  logits; // TODO: rename this to "output"
+        llama_mtp_params mtp_params;
     } llama_batch;

     enum llama_model_kv_override_type {

@@ -1454,8 +1466,37 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);

-    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
+    //
+    // MTP
+    //
+
+    LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
+
+    /**
+     * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
+     * This is used after speculative validation when only a subset of draft tokens are accepted.
+     * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
+     * @return true on success.
+     */
+    LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
+
+    /**
+     * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
+     * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
+     * @return true on success.
+     */
+    LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
+
+    /**
+     * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
+     */
+    LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
+
+    /**
+     * @brief Removes KV cache metadata for a specified sequence and token range.
+     * This makes the physical cells logically available again without deleting the tensor data.
+     */
+    LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);

 #ifdef __cplusplus
 }
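
For existing C API callers the only visible change is the new mtp_params field on llama_batch, which llama_batch_init() sets to MTP_OP_NONE (see src/llama-batch.cpp below). As a hedged sketch, a single-token MTP draft decode through this API would mirror mtp_speculative_gen_draft above; the wrapper function and variable names are assumptions, the calls themselves come from this diff.

// Sketch only: mirrors the draft-generation path in common/speculative.cpp.
#include "llama.h"
#include "common.h"   // for common_batch_add()

static void mtp_draft_decode_sketch(struct llama_context * ctx, llama_token id_last, llama_pos n_past) {
    llama_batch batch = llama_batch_init(/*n_tokens_alloc=*/1, /*embd=*/0, /*n_seq_max=*/1);
    common_batch_add(batch, id_last, n_past, { 0 }, true);
    batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;   // route this decode through the MTP head
    llama_decode(ctx, batch);
    // Make the draft position's cells logically free again for the validation pass.
    llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, n_past, n_past + 1);
    llama_batch_free(batch);
}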

src/llama-batch.cpp

Lines changed: 8 additions & 7 deletions
@@ -834,13 +834,14 @@ struct llama_batch llama_batch_get_one(

 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
-        /*n_tokens =*/ 0,
-        /*tokens   =*/ nullptr,
-        /*embd     =*/ nullptr,
-        /*pos      =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id   =*/ nullptr,
-        /*logits   =*/ nullptr,
+        /*n_tokens    =*/ 0,
+        /*tokens      =*/ nullptr,
+        /*embd        =*/ nullptr,
+        /*pos         =*/ nullptr,
+        /*n_seq_id    =*/ nullptr,
+        /*seq_id      =*/ nullptr,
+        /*logits      =*/ nullptr,
+        /*.mtp_params =*/ { MTP_OP_NONE },
     };

     if (embd) {
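
Because the new field is initialized to MTP_OP_NONE here, batches built by existing code keep the normal decode behavior unless a caller explicitly opts into an MTP operation. A tiny illustrative check (not from this commit):

// Hypothetical sanity check for the default op type.
#include <cassert>
#include "llama.h"

int main() {
    llama_batch b = llama_batch_init(8, 0, 1);
    assert(b.mtp_params.op_type == MTP_OP_NONE);   // default: ordinary decode path
    llama_batch_free(b);
    return 0;
}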
