Commit 4bcc9e2

mtp-batch(fix): Correctly advance cache head and add MTP documentation
Parent: b4cbe03

4 files changed (+45, −18 lines)

common/speculative.cpp

Lines changed: 4 additions & 0 deletions
@@ -436,17 +436,21 @@ void mtp_accept_tokens(
         return;
     }
 
+    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
+    // This sets up the context for a "forced sinfo" decode.
     if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
         return;
     }
 
+    // Build a new batch containing only the accepted tokens.
     llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
     for (size_t i = 0; i < ids.size(); ++i) {
         common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
     }
 
     mtp_update_kv_cache(ctx, accepted_batch, false);
 
+    // Clean up the forced state to not affect subsequent, normal decode calls.
     llama_mtp_cancel_sinfo_update(ctx);
 
     llama_batch_free(accepted_batch);

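The calls above bracket a single MTP cache update: llama_mtp_prepare_sinfo_for_update() forces the resized sinfo, and llama_mtp_cancel_sinfo_update() must always clear it again afterwards. As a hedged illustration of that pairing (not part of this commit, and using only the llama.h functions it adds), an RAII guard could make the bracket explicit:

#include <cstddef>
#include "llama.h"

// Hypothetical helper, for illustration only: ties the forced-sinfo lifetime
// to a scope so an early return cannot leave the forced state set.
struct mtp_forced_sinfo_guard {
    llama_context * ctx   = nullptr;
    bool            armed = false;

    mtp_forced_sinfo_guard(llama_context * ctx_, size_t n_accepted) : ctx(ctx_) {
        armed = llama_mtp_prepare_sinfo_for_update(ctx, n_accepted);
    }

    ~mtp_forced_sinfo_guard() {
        if (armed) {
            llama_mtp_cancel_sinfo_update(ctx); // always restore normal decoding
        }
    }

    explicit operator bool() const { return armed; }
};

With such a guard, the body of mtp_accept_tokens() would reduce to building the accepted-token batch and calling mtp_update_kv_cache() inside the guarded scope.
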
include/llama.h

Lines changed: 22 additions & 0 deletions
@@ -1466,14 +1466,36 @@ extern "C" {
                                       ggml_opt_epoch_callback callback_train,
                                       ggml_opt_epoch_callback callback_eval);
 
+    //
+    // MTP
+    //
+
     LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
 
+    /**
+     * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
+     * This is used after speculative validation when only a subset of draft tokens are accepted.
+     * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
+     * @return true on success.
+     */
     LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
 
+    /**
+     * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
+     * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
+     * @return true on success.
+     */
     LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
 
+    /**
+     * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
+     */
     LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
 
+    /**
+     * @brief Removes KV cache metadata for a specified sequence and token range.
+     * This makes the physical cells logically available again without deleting the tensor data.
+     */
     LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
 
 #ifdef __cplusplus

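To illustrate the last of these declarations, here is a hedged usage sketch (an assumed call site, not code from this commit) in which the metadata for the rejected tail of a speculative draft is dropped after a partial acceptance. It assumes the usual llama.cpp convention that the range is half-open, i.e. positions in [p0, p1) are removed for the given sequence:

#include "llama.h"

// Hypothetical rollback after partial acceptance of a draft of n_draft tokens.
// n_past is the number of tokens already committed before the draft started.
static void rollback_rejected_draft(llama_context * ctx, llama_seq_id seq_id,
                                    llama_pos n_past, int n_accepted, int n_draft) {
    if (n_accepted < n_draft) {
        // Frees only the cell metadata for the rejected positions; the tensor
        // data stays in place and is overwritten by later decodes.
        llama_kv_cache_seq_rm(ctx, seq_id, n_past + n_accepted, n_past + n_draft);
    }
}
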
src/llama-context.h

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,10 @@ struct llama_context {
 
     ~llama_context();
 
+    // The llama_context manages significant resources (GPU memory, file handles, PImpl data)
+    // and is fundamentally a non-copyable, non-movable object. Deleting these special
+    // member functions enforces this rule and is also technically required to allow the
+    // PImpl pattern (via unique_ptr or void*) with an incomplete type in the header.
     llama_context(const llama_context &) = delete;
     llama_context & operator=(const llama_context &) = delete;
     llama_context(llama_context &&) = delete;
src/llama-kv-cache-unified.cpp

Lines changed: 15 additions & 18 deletions
@@ -977,6 +977,10 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 }
 
 void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) {
+    // For "in-place" updates (MTP warmup/accept), we only update the tensor data.
+    // The cell metadata (logical position, sequence ID) has already been set
+    // by the main model's pass. We must skip all metadata modifications
+    // to prevent `pos_set` from asserting on an already-set cell.
     if (!is_inplace_update) {
         // keep track of the max sequence position that we would overwrite with this ubatch
         // for non-SWA cache, this would be always empty
@@ -995,17 +999,12 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
 
                 const auto idx = sinfo.idxs[s][ii];
 
-                if (!is_inplace_update) {
-                    if (!cells.is_empty(idx)) {
-                        assert(cells.seq_count(idx) == 1);
-
-                        const llama_seq_id seq_id = cells.seq_get(idx);
-                        const llama_pos pos = cells.pos_get(idx);
-
-                        seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
-
-                        cells.rm(idx);
-                    }
+                if (!cells.is_empty(idx)) {
+                    assert(cells.seq_count(idx) == 1);
+                    const llama_seq_id seq_id = cells.seq_get(idx);
+                    const llama_pos pos = cells.pos_get(idx);
+                    seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+                    cells.rm(idx);
                 }
 
                 cells.pos_set(idx, ubatch.pos[i]);
@@ -1029,19 +1028,17 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
             auto & cells = v_cells[seq_to_stream[s]];
 
             if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
-                LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
-                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
 
                 seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
             }
         }
+    }
 
-        // move the head at the end of the slot
-        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-            auto & head = v_heads[sinfo.strm[s]];
+    // move the head at the end of the slot
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        auto & head = v_heads[sinfo.strm[s]];
 
-            head = sinfo.idxs[s].back() + 1;
-        }
+        head = sinfo.idxs[s].back() + 1;
     }
 }

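The comment in the first hunk is the key invariant: for an MTP in-place pass the cells are already populated, and setting a cell's position twice trips an assert. A toy model of that behaviour (illustration only, not llama.cpp code):

#include <cassert>
#include <cstdint>

// Minimal stand-in for a KV cell whose position may be set only while empty,
// mirroring the pos_set() assert the new comment refers to.
struct toy_cell {
    int32_t pos = -1;                  // -1 means "empty"

    void pos_set(int32_t p) {
        assert(pos == -1 && "pos_set on an already-set cell");
        pos = p;
    }
};

int main() {
    toy_cell cell;
    cell.pos_set(42);                  // main model pass writes the metadata

    const bool is_inplace_update = true;   // MTP warmup/accept pass over the same cell
    if (!is_inplace_update) {
        cell.pos_set(42);              // skipped; would assert without the guard
    }
    return 0;
}

The remaining hunks complete the fix named in the commit title: the head-advance loop is moved out of the is_inplace_update guard, so after every apply_ubatch() the per-stream head points one past the last cell the ubatch touched, for in-place MTP updates as well as normal decodes.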