Skip to content
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
db60623
added getter for nextn layer count and server slot has_mtp property
F1LM1 Aug 11, 2025
e434f87
some work towards building mtp layer graph
F1LM1 Aug 11, 2025
1f477b3
make nextn weights loadable without a crash
F1LM1 Aug 12, 2025
03231da
add model member function to build mtp graph, to be called from specu…
F1LM1 Aug 12, 2025
cf0f7c0
broad thrust of the mtp implementation
F1LM1 Aug 13, 2025
6e9bafc
failed attempt to implement MTP; outputs tokens but KV cache manageme…
F1LM1 Aug 16, 2025
6870f97
added proper KV cache management for MTP layers and slightly refactored
F1LM1 Aug 17, 2025
382135a
fixed mtp kv cache update sequencing after prompt processing
F1LM1 Aug 18, 2025
d72f9d5
kludge-y kv cache management of mtp layer
F1LM1 Aug 19, 2025
471e026
fixed vram leak
F1LM1 Aug 20, 2025
98bc0c6
replace standard sampler with greedy sampler for mtp draft
F1LM1 Aug 26, 2025
9fab53e
fixed mtp kv cache update step in cases where prompt size > n_batch a…
F1LM1 Sep 2, 2025
07670a2
feat: implemented sampling for MTP
SamuelOliveirads Sep 3, 2025
5a5bce8
fix: add sample acceptance
SamuelOliveirads Sep 3, 2025
8742ce0
feat: apply logits + greedy sampler
SamuelOliveirads Sep 6, 2025
c6237c7
Merge pull request #1 from SamuelOliveirads/glm4-moe-mtp
F1LM1 Sep 13, 2025
1318b2d
mtp-batch (wip): move mtp execution to batch format
SamuelOliveirads Sep 14, 2025
042eb8a
mtp-batch (wip): merge mtp and model graph
SamuelOliveirads Sep 22, 2025
df64508
mtp-batch (wip): merge glm graphs
SamuelOliveirads Sep 22, 2025
3da7e7f
mtp-batch (fix): warm mtp cache for small batch size
SamuelOliveirads Sep 24, 2025
75dc25e
mtp-batch (wip): organize batch for mtp cache
SamuelOliveirads Sep 27, 2025
67c6c06
mtp-batch (wip): Isolate MTP graph to prevent host embedding buffer c…
SamuelOliveirads Sep 27, 2025
febd823
mtp-batch (wip): fix how to warmup kv cache for MTP
SamuelOliveirads Oct 5, 2025
5e1d719
mtp-batch (feat): Create and manage sinfo for MTP
SamuelOliveirads Oct 9, 2025
6f74ba3
mtp-batch (fix): prevent mtp draft from polluting the cache
SamuelOliveirads Oct 10, 2025
913af8f
mtp-batch(refactor): Replace MTP boolean flags with an explicit opera…
SamuelOliveirads Oct 10, 2025
a99709d
mtp-batch(refactor): Extract decode context and MTP input logic into …
SamuelOliveirads Oct 10, 2025
b4cbe03
mtp-batch(chore): Fix logit flags for speculative sampling and remove…
SamuelOliveirads Oct 11, 2025
4bcc9e2
mtp-batch(fix): Correctly advance cache head and add MTP documentation
SamuelOliveirads Oct 11, 2025
0127c6b
mtp-batch(chore): Remove final MTP debug logs and dead code
SamuelOliveirads Oct 12, 2025
cae85fe
mtp-batch(fix): avoid logits for mtp kv cache operations
SamuelOliveirads Oct 16, 2025
c2d7c76
Merge pull request #3 from SamuelOliveirads/glm4-mtp-batch
F1LM1 Oct 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co

llama_sampler_apply(chain, &cur_p);

/*for (int k = 0; k < (int)cur_p.size; ++k) {
LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n",
k, 0, cur_p.data[k].id, cur_p.data[k].p);
}*/

GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

const llama_token id = cur_p.data[cur_p.selected].id;
Expand Down
73 changes: 73 additions & 0 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "log.h"
#include "common.h"
#include "sampling.h"
#include "../src/llama-graph.h"
#include "../src/llama-context.h"

#include <cstring>
#include <algorithm>
Expand Down Expand Up @@ -359,3 +361,74 @@ llama_tokens common_speculative_gen_draft(
}
return result;
}


llama_token mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx) {

llama_token token_data[] = { id_last };
llama_pos pos_data[] = { n_past };
int32_t n_seq_id_data[] = { 1 };
llama_seq_id seq_id_data_internal[] = { 0 };
llama_seq_id* seq_id_data[] = {seq_id_data_internal};
int8_t logits_data[] = { (int8_t) (smpl != nullptr) };

llama_batch batch = {
/*.n_tokens = */ 1,
/*.token = */ token_data,
/*.embd = */ nullptr,
/*.pos = */ pos_data,
/*.n_seq_id = */ n_seq_id_data,
/*.seq_id = */ seq_id_data,
/*.logits = */ logits_data
};

return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
//LOG_INF("updating kv cache for n_past: %d\n", n_past);

/*
if (!smpl) {
return -1;
}
else {
common_sampler_sample(smpl, ctx, last_tok_idx, true);
const auto* cur_p = common_sampler_get_candidates(smpl);

//for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
// LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
// k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
//}

const llama_token id = cur_p->data[0].id;
return id;
}
*/
// LOG_INF("cur_p->size: %d\n", cur_p->size);


// add drafted token for each sequence

// skip accepting draft token -- since we're only drafting one token this can't affect future outputs
// smpl will accept the token if it doesn't get rejected by main model later
// common_sampler_accept(smpl, id, true);

//llama_tokens result;
//result.reserve(1);
//result.push_back(id);
//return result;
}


void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
mtp_kv_update_data token;
for (int i = 0; i < tokens.size(); ++i) {
token = tokens[i];
mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
}

tokens.clear();
}
17 changes: 17 additions & 0 deletions common/speculative.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ struct common_speculative_params {
float p_min = 0.75f; // min probability required to accept a token in the draft
};

struct mtp_kv_update_data {
llama_token id;
int32_t n_past;
int32_t tok_idx;
};

struct common_speculative * common_speculative_init(
struct llama_context * ctx_tgt,
struct llama_context * ctx_dft
Expand All @@ -27,9 +33,20 @@ void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest);


// sample up to n_draft tokens and add them to the batch using the draft model
llama_token mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx);

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);

void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);
7 changes: 7 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,8 @@ extern "C" {

LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);

LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);

// Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
Expand Down Expand Up @@ -548,6 +550,8 @@ extern "C" {
const char * fname_out,
const llama_model_quantize_params * params);



//
// Adapters
//
Expand Down Expand Up @@ -1450,6 +1454,9 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);

LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);

#ifdef __cplusplus
}
#endif
Expand Down
13 changes: 7 additions & 6 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2240,12 +2240,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// Changed to LLM_TENSOR_LAYER_REPEATING because we saved these under a blk with a non-negative id
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};

LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
Expand Down
6 changes: 4 additions & 2 deletions src/llama-batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ bool llama_batch_allocr::init(
}
}

if (!ok) {
// TEMPORARILY DISABLING THIS SANITY CHECK
// TODO: UNDO THIS IF IT WORKS
/*if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
Expand All @@ -284,7 +286,7 @@ bool llama_batch_allocr::init(
__func__, s, s, p0, s, seq_pos_min(s));

return false;
}
}*/
}

if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
Expand Down
Loading