Changes from all commits (32 commits)
db60623
added getter for nextn layer count and server slot has_mtp property
F1LM1 Aug 11, 2025
e434f87
some work towards building mtp layer graph
F1LM1 Aug 11, 2025
1f477b3
make nextn weights loadable without a crash
F1LM1 Aug 12, 2025
03231da
add model member function to build mtp graph, to be called from specu…
F1LM1 Aug 12, 2025
cf0f7c0
broad thrust of the mtp implementation
F1LM1 Aug 13, 2025
6e9bafc
failed attempt to implement MTP; outputs tokens but KV cache manageme…
F1LM1 Aug 16, 2025
6870f97
added proper KV cache management for MTP layers and slightly refactored
F1LM1 Aug 17, 2025
382135a
fixed mtp kv cache update sequencing after prompt processing
F1LM1 Aug 18, 2025
d72f9d5
kludge-y kv cache management of mtp layer
F1LM1 Aug 19, 2025
471e026
fixed vram leak
F1LM1 Aug 20, 2025
98bc0c6
replace standard sampler with greedy sampler for mtp draft
F1LM1 Aug 26, 2025
9fab53e
fixed mtp kv cache update step in cases where prompt size > n_batch a…
F1LM1 Sep 2, 2025
07670a2
feat: implemented sampling for MTP
SamuelOliveirads Sep 3, 2025
5a5bce8
fix: add sample acceptance
SamuelOliveirads Sep 3, 2025
8742ce0
feat: apply logits + greedy sampler
SamuelOliveirads Sep 6, 2025
c6237c7
Merge pull request #1 from SamuelOliveirads/glm4-moe-mtp
F1LM1 Sep 13, 2025
1318b2d
mtp-batch (wip): move mtp execution to batch format
SamuelOliveirads Sep 14, 2025
042eb8a
mtp-batch (wip): merge mtp and model graph
SamuelOliveirads Sep 22, 2025
df64508
mtp-batch (wip): merge glm graphs
SamuelOliveirads Sep 22, 2025
3da7e7f
mtp-batch (fix): warm mtp cache for small batch size
SamuelOliveirads Sep 24, 2025
75dc25e
mtp-batch (wip): organize batch for mtp cache
SamuelOliveirads Sep 27, 2025
67c6c06
mtp-batch (wip): Isolate MTP graph to prevent host embedding buffer c…
SamuelOliveirads Sep 27, 2025
febd823
mtp-batch (wip): fix how to warmup kv cache for MTP
SamuelOliveirads Oct 5, 2025
5e1d719
mtp-batch (feat): Create and manage sinfo for MTP
SamuelOliveirads Oct 9, 2025
6f74ba3
mtp-batch (fix): prevent mtp draft from polluting the cache
SamuelOliveirads Oct 10, 2025
913af8f
mtp-batch(refactor): Replace MTP boolean flags with an explicit opera…
SamuelOliveirads Oct 10, 2025
a99709d
mtp-batch(refactor): Extract decode context and MTP input logic into …
SamuelOliveirads Oct 10, 2025
b4cbe03
mtp-batch(chore): Fix logit flags for speculative sampling and remove…
SamuelOliveirads Oct 11, 2025
4bcc9e2
mtp-batch(fix): Correctly advance cache head and add MTP documentation
SamuelOliveirads Oct 11, 2025
0127c6b
mtp-batch(chore): Remove final MTP debug logs and dead code
SamuelOliveirads Oct 12, 2025
cae85fe
mtp-batch(fix): avoid logits for mtp kv cache operations
SamuelOliveirads Oct 16, 2025
c2d7c76
Merge pull request #3 from SamuelOliveirads/glm4-mtp-batch
F1LM1 Oct 20, 2025
9 changes: 9 additions & 0 deletions common/sampling.cpp
@@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co

llama_sampler_apply(chain, &cur_p);

/*for (int k = 0; k < (int)cur_p.size; ++k) {
LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n",
k, 0, cur_p.data[k].id, cur_p.data[k].p);
}*/

GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

const llama_token id = cur_p.data[cur_p.selected].id;
@@ -577,3 +582,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri

return samplers;
}

void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
llama_sampler_apply(gsmpl->chain, cur_p);
}
2 changes: 2 additions & 0 deletions common/sampling.h
@@ -105,3 +105,5 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:

llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);

void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
96 changes: 96 additions & 0 deletions common/speculative.cpp
@@ -5,6 +5,8 @@
#include "log.h"
#include "common.h"
#include "sampling.h"
#include "../src/llama-graph.h"
#include "../src/llama-context.h"

#include <cstring>
#include <algorithm>
@@ -359,3 +361,97 @@ llama_tokens common_speculative_gen_draft(
}
return result;
}


llama_token mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx) {

if (!smpl) {
return -1;
}
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
const llama_pos draft_pos = n_past;
const llama_seq_id draft_seq_id = 0;
common_batch_add(mtp_batch, id_last, draft_pos, { draft_seq_id }, true);

mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;

// Perform the MTP draft generation decode. This writes the MTP layer's
// KV state for the draft token into the cache.
llama_decode(ctx, mtp_batch);
llama_batch_free(mtp_batch);

// CRITICAL: Purge the metadata for the draft token we just wrote.
// This makes the physical cell available again for the main model's validation pass,
// preventing a cache state corruption where two cells map to the same logical position.
llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);

const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
cur_p->size = n_vocab;
for (int i = 0; i < n_vocab; ++i) {
cur_p->data[i].id = i;
cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0.
}
cur_p->sorted = false;
common_sampler_apply_chain(smpl, cur_p);

return cur_p->data[0].id;
}


void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
if (batch.n_tokens == 0) {
return;
}

LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);

llama_batch mtp_batch = batch;
if (is_prompt_warmup) {
mtp_batch.mtp_params.op_type = MTP_OP_WARMUP;
} else {
mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
}

for (int i = 0; i < mtp_batch.n_tokens; ++i) {
mtp_batch.logits[i] = true;
}
llama_decode(ctx, mtp_batch);
}

void mtp_accept_tokens(
struct llama_context * ctx,
const std::vector<llama_token> & ids,
int32_t n_past_base,
llama_seq_id seq_id
) {
if (ids.empty()) {
return;
}

// Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
// This sets up the context for a "forced sinfo" decode.
if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
return;
}

// Build a new batch containing only the accepted tokens.
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
for (size_t i = 0; i < ids.size(); ++i) {
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
}

mtp_update_kv_cache(ctx, accepted_batch, false);

// Clean up the forced state to not affect subsequent, normal decode calls.
llama_mtp_cancel_sinfo_update(ctx);

llama_batch_free(accepted_batch);
}
32 changes: 28 additions & 4 deletions common/speculative.h
@@ -12,6 +12,12 @@ struct common_speculative_params {
float p_min = 0.75f; // min probability required to accept a token in the draft
};

struct mtp_kv_update_data {
llama_token id;
int32_t n_past;
int32_t tok_idx;
};

struct common_speculative * common_speculative_init(
struct llama_context * ctx_tgt,
struct llama_context * ctx_dft
@@ -27,9 +33,27 @@ void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest);


// generate a single draft token using the model's MTP (NextN) head
llama_token mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx);

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);

void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);

void mtp_accept_tokens(
struct llama_context * ctx,
const std::vector<llama_token> & ids,
int32_t n_past_base,
llama_seq_id seq_id
);
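
For context on how these declarations are expected to be driven, here is a minimal sketch of one speculative step in a server-style loop. It assumes `ctx`, `smpl`, `id_last` and `n_past` come from the surrounding slot state and leaves out the main-model validation decode; only the mtp_* calls are part of this PR, the wrapper function and acceptance flag are illustrative.

#include <vector>
#include "speculative.h"
#include "sampling.h"

// Sketch only: one MTP speculative step (main-model validation omitted).
static llama_token mtp_speculative_step(llama_context * ctx, common_sampler * smpl,
        llama_token id_last, int32_t & n_past) {
    // draft a single candidate token with the MTP (NextN) head
    const llama_token id_draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, 0);

    // validate with the main model here: decode id_last and id_draft in one batch,
    // sample at both positions, and keep the draft only if it matches (omitted)
    const bool accept_draft = false; // decided by the validation step (assumption)
    std::vector<llama_token> accepted = { id_last };
    if (accept_draft) {
        accepted.push_back(id_draft);
    }

    // re-run the committed tokens through the MTP layer so its KV cache stays
    // aligned with the main model's cache; positions start at n_past, matching
    // the convention used by mtp_speculative_gen_draft above
    mtp_accept_tokens(ctx, accepted, n_past, 0);

    n_past += (int32_t) accepted.size();
    return accepted.back();
}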
48 changes: 48 additions & 0 deletions include/llama.h
@@ -221,6 +221,17 @@ extern "C" {
// - if not: only the last token is output
// )
//
typedef enum {
MTP_OP_NONE,
MTP_OP_WARMUP,
MTP_OP_UPDATE_ACCEPTED,
MTP_OP_DRAFT_GEN,
} llama_mtp_op_type;

typedef struct llama_mtp_params {
llama_mtp_op_type op_type;
} llama_mtp_params;

typedef struct llama_batch {
int32_t n_tokens;

@@ -230,6 +241,7 @@ extern "C" {
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
llama_mtp_params mtp_params;
} llama_batch;

enum llama_model_kv_override_type {
@@ -495,6 +507,8 @@ extern "C" {

LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);

LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);

// Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
@@ -548,6 +562,8 @@ extern "C" {
const char * fname_out,
const llama_model_quantize_params * params);



//
// Adapters
//
@@ -1450,6 +1466,38 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);

//
// MTP
//

LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);

/**
* @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
* This is used after speculative validation when only a subset of draft tokens are accepted.
* @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
* @return true on success.
*/
LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);

/**
* @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
* This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
* @return true on success.
*/
LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);

/**
* @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
*/
LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);

/**
* @brief Removes KV cache metadata for a specified sequence and token range.
* This makes the physical cells logically available again without deleting the tensor data.
*/
LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);

#ifdef __cplusplus
}
#endif
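
Taken together, the sinfo helpers above bracket a single llama_decode call: prepare a sinfo, decode a batch tagged with an MTP op type, then cancel. The sketch below shows one plausible prompt-warmup sequence at this API level, mirroring what mtp_update_kv_cache in common/speculative.cpp does internally; it assumes `ctx` and a `prompt` token vector are in scope and that the prompt fits in a single batch.

// Sketch: warming the MTP layer's KV cache after the main model decoded the prompt.
llama_batch warmup = llama_batch_init((int32_t) prompt.size(), 0, 1);
for (size_t i = 0; i < prompt.size(); ++i) {
    common_batch_add(warmup, prompt[i], (llama_pos) i, { 0 }, true);
}
warmup.mtp_params.op_type = MTP_OP_WARMUP;

// reuse the sinfo of the main model's prompt decode so both caches line up
if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
    llama_decode(ctx, warmup);
    // always clear the forced sinfo before the next regular decode
    llama_mtp_cancel_sinfo_update(ctx);
}
llama_batch_free(warmup);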
13 changes: 7 additions & 6 deletions src/llama-arch.cpp
@@ -2240,12 +2240,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// Changed to LLM_TENSOR_LAYER_REPEATING because we saved these under a blk with a non-negative id
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};

LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
21 changes: 12 additions & 9 deletions src/llama-batch.cpp
@@ -275,7 +275,9 @@ bool llama_batch_allocr::init(
}
}

if (!ok) {
// TEMPORARILY DISABLING THIS SANITY CHECK
// TODO: UNDO THIS IF IT WORKS
/*if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
Expand All @@ -284,7 +286,7 @@ bool llama_batch_allocr::init(
__func__, s, s, p0, s, seq_pos_min(s));

return false;
}
}*/
}

if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
@@ -832,13 +834,14 @@ struct llama_batch llama_batch_get_one(

struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
llama_batch batch = {
/*n_tokens =*/ 0,
/*tokens =*/ nullptr,
/*embd =*/ nullptr,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*n_tokens =*/ 0,
/*tokens =*/ nullptr,
/*embd =*/ nullptr,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*.mtp_params =*/ { MTP_OP_NONE },
};

if (embd) {