
Commit 913af8f

mtp-batch(refactor): Replace MTP boolean flags with an explicit operation enum
1 parent 6f74ba3 commit 913af8f

File tree: 8 files changed (+113, -108 lines)

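The gist of the refactor, summarized as a minimal before/after sketch reconstructed from the hunks below. The batch setup around it is illustrative and not taken from this commit:

    // Before this commit: three coupled booleans on llama_batch had to be kept
    // consistent by hand (the removed fields, shown here only as comments):
    //   batch.update_mtp_kv        = ...;
    //   batch.use_mtp_head         = ...;
    //   batch.is_mtp_prompt_warmup = ...;

    // After this commit: a single explicit operation tag.
    llama_batch batch = llama_batch_init(/*n_tokens_alloc=*/1, /*embd=*/0, /*n_seq_max=*/1);
    batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;   // or MTP_OP_WARMUP, MTP_OP_UPDATE_ACCEPTED, MTP_OP_NONE

    // ... fill the batch and call llama_decode(ctx, batch) as usual ...
    llama_batch_free(batch);

The enum makes the intent of each decode call explicit at the call site and removes the need to reason about which boolean combinations are valid.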

common/speculative.cpp

Lines changed: 16 additions & 10 deletions

@@ -378,17 +378,21 @@ llama_token mtp_speculative_gen_draft(
     const llama_seq_id draft_seq_id = 0;
     common_batch_add(mtp_batch, id_last, n_past, {0}, true);
 
-    mtp_batch.update_mtp_kv = false;
-    mtp_batch.use_mtp_head = true;
+    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
 
-    LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n",
-        mtp_batch.update_mtp_kv ? "true" : "false",
-        mtp_batch.use_mtp_head ? "true" : "false"
-    );
+    // LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n",
+    //     mtp_batch.update_mtp_kv ? "true" : "false",
+    //     mtp_batch.use_mtp_head ? "true" : "false"
+    // );
 
+    // Perform the MTP draft generation decode. This writes the MTP layer's
+    // KV state for the draft token into the cache.
     llama_decode(ctx, mtp_batch);
     llama_batch_free(mtp_batch);
 
+    // CRITICAL: Purge the metadata for the draft token we just wrote.
+    // This makes the physical cell available again for the main model's validation pass,
+    // preventing a cache state corruption where two cells map to the same logical position.
     llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);
 
     const llama_model * model = llama_get_model(ctx);

@@ -398,7 +402,7 @@ llama_token mtp_speculative_gen_draft(
     cur_p->size = n_vocab;
     for (int i = 0; i < n_vocab; ++i) {
         cur_p->data[i].id = i;
-        cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // TODO: check if position 0 is the right
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0.
     }
     cur_p->sorted = false;
     common_sampler_apply_chain(smpl, cur_p);

@@ -415,9 +419,11 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
     LOG_INF("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
 
     llama_batch mtp_batch = batch;
-    mtp_batch.update_mtp_kv = true;
-    mtp_batch.use_mtp_head = true;
-    mtp_batch.is_mtp_prompt_warmup = is_prompt_warmup;
+    if (is_prompt_warmup) {
+        mtp_batch.mtp_params.op_type = MTP_OP_WARMUP;
+    } else {
+        mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
+    }
 
     for (int i = 0; i < mtp_batch.n_tokens; ++i) {
         mtp_batch.logits[i] = false;
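For orientation, here is the draft-generation path above condensed into one self-contained sketch. The function name is hypothetical, the arguments mirror the caller state used by mtp_speculative_gen_draft, sampling is omitted, and only calls that appear in the hunk are used:

    // Condensed sketch of the MTP_OP_DRAFT_GEN path after this change.
    static void mtp_draft_step_sketch(llama_context * ctx, llama_token id_last, llama_pos n_past, llama_pos draft_pos) {
        llama_batch mtp_batch = llama_batch_init(/*n_tokens_alloc=*/1, /*embd=*/0, /*n_seq_max=*/1);
        common_batch_add(mtp_batch, id_last, n_past, {0}, /*logits=*/true);
        mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;   // run the MTP head on the stored draft hidden state

        llama_decode(ctx, mtp_batch);                      // writes the MTP layer's KV entry for the draft position
        llama_batch_free(mtp_batch);

        // Release the draft cell again so the main model's validation pass can
        // claim the same logical position without two cells aliasing it.
        llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, draft_pos, draft_pos + 1);

        // Single-token batch: the draft head's logits live at output index 0.
        const float * draft_logits = llama_get_logits_ith(ctx, 0);
        (void) draft_logits; // a sampler chain would pick the draft token from these
    }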

include/llama.h

Lines changed: 12 additions & 3 deletions

@@ -221,6 +221,17 @@ extern "C" {
     //  - if not: only the last token is output
     // )
     //
+    typedef enum {
+        MTP_OP_NONE,
+        MTP_OP_WARMUP,
+        MTP_OP_UPDATE_ACCEPTED,
+        MTP_OP_DRAFT_GEN,
+    } llama_mtp_op_type;
+
+    typedef struct llama_mtp_params {
+        llama_mtp_op_type op_type;
+    } llama_mtp_params;
+
     typedef struct llama_batch {
         int32_t n_tokens;
 

@@ -230,9 +241,7 @@ extern "C" {
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       *  logits; // TODO: rename this to "output"
-        bool update_mtp_kv;
-        bool use_mtp_head;
-        bool is_mtp_prompt_warmup;
+        llama_mtp_params mtp_params;
     } llama_batch;
 
     enum llama_model_kv_override_type {
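If the commented-out debug logging in the other files is ever revived, it will want a readable name per operation instead of two booleans. A possible helper along those lines, not part of this commit:

    // Hypothetical helper (not in this commit): printable name for an MTP op,
    // useful for log lines that previously printed update_mtp_kv / use_mtp_head.
    static const char * mtp_op_name(llama_mtp_op_type op) {
        switch (op) {
            case MTP_OP_NONE:            return "NONE";
            case MTP_OP_WARMUP:          return "WARMUP";
            case MTP_OP_UPDATE_ACCEPTED: return "UPDATE_ACCEPTED";
            case MTP_OP_DRAFT_GEN:       return "DRAFT_GEN";
        }
        return "UNKNOWN";
    }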

src/llama-batch.cpp

Lines changed: 1 addition & 3 deletions

@@ -841,9 +841,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_seq_id  =*/ nullptr,
         /*seq_id    =*/ nullptr,
         /*logits    =*/ nullptr,
-        /*.use_mtp_head =*/ false,
-        /*update_mtp_kv =*/ false,
-        /*.is_mtp_prompt_warmup =*/ false,
+        /*.mtp_params =*/ { MTP_OP_NONE },
     };
 
     if (embd) {
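A small property worth noting next to this initializer: MTP_OP_NONE is the first enumerator, so its value is 0, and a zero- or value-initialized llama_batch therefore already means "no MTP operation". A quick sketch of that invariant (the assert is illustrative, not in the commit):

    // MTP_OP_NONE is the first enumerator, so it is 0 by the C/C++ rules;
    // zero-initializing a llama_batch defaults to "no MTP operation".
    static_assert(MTP_OP_NONE == 0, "MTP_OP_NONE must remain the zero value");

    llama_batch b = {};   // value-initialized batch
    // b.mtp_params.op_type == MTP_OP_NONE holds here without further setup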

src/llama-context.cpp

Lines changed: 51 additions & 53 deletions

@@ -750,7 +750,7 @@ static double calculate_vector_sum(const float* vec, size_t size) {
 }
 
 llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
-    bool do_mtp_kv_update, bool use_mtp_head, bool is_mtp_prompt_warmup) {
+    const llama_mtp_params & mtp_params) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;

@@ -762,7 +762,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
 
     // the new graph parameters
     // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
-    const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update, use_mtp_head);
+    const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params);
 
     if (!graph_reuse_disable && res->can_reuse(gparams)) {
         //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);

@@ -793,22 +793,22 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         }
     }
 
-    if (do_mtp_kv_update || (use_mtp_head && !do_mtp_kv_update)) { // If it is any MTP operation
+    if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
         const char * target_tensor_name = "result_embd_pooled";
         ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
 
         const float * source_hidden_state = nullptr;
-        if (is_mtp_prompt_warmup || (do_mtp_kv_update && !is_mtp_prompt_warmup)) {
+        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
             source_hidden_state = this->embd;
         } else {
             source_hidden_state = this->draft_input_hidden_state;
         }
 
         if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
             const size_t n_embd = this->model.hparams.n_embd;
-            const size_t n_tokens_for_sum = (do_mtp_kv_update && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
+            const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
             double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
-            const char * op_type = (do_mtp_kv_update) ? "MTP_UPDATE" : "DRAFT_GEN";
+            const char * op_type = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) ? "MTP_UPDATE" : "DRAFT_GEN";
 
             LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
 

@@ -833,20 +833,20 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     const int64_t t_exec_start_us = ggml_time_us();
     const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     const int64_t t_exec_end_us = ggml_time_us();
-    LLAMA_LOG_INFO(
-        "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
-        (t_exec_end_us - t_exec_start_us) / 1000.0,
-        ubatch.n_tokens,
-        do_mtp_kv_update ? "yes" : "no"
-    );
+    // LLAMA_LOG_INFO(
+    //     "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
+    //     (t_exec_end_us - t_exec_start_us) / 1000.0,
+    //     ubatch.n_tokens,
+    //     do_mtp_kv_update ? "yes" : "no"
+    // );
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
         return nullptr;
     }
 
     ret = GGML_STATUS_SUCCESS;
-    if (do_mtp_kv_update || use_mtp_head) {
+    if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
         ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum");
         if (sum_tensor) {
             LLAMA_LOG_WARN("[DEBUG-SUM] MTP input sum node successfully created.\n");

@@ -912,7 +912,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false, false);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, { MTP_OP_NONE });
 
     cparams.causal_attn = causal_attn_org;
 

@@ -1027,10 +1027,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
     auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
-    LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n",
-        batch_inp.update_mtp_kv ? "true" : "false",
-        batch_inp.use_mtp_head ? "true" : "false"
-    );
+    // LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n",
+    //     batch_inp.update_mtp_kv ? "true" : "false",
+    //     batch_inp.use_mtp_head ? "true" : "false"
+    // );
 
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);

@@ -1101,7 +1101,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     } else {
         mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
 
-        if (!batch_inp.use_mtp_head && !batch_inp.update_mtp_kv) {
+        if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
             if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
                 kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
             } else {

@@ -1158,9 +1158,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
     };
 
     int64_t n_outputs_prev = 0;
-    const bool do_mtp_kv_update = batch_inp.update_mtp_kv;
-    const bool use_mtp_head = batch_inp.use_mtp_head;
-    const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup;
+    // const bool do_mtp_kv_update = batch_inp.update_mtp_kv;
+    // const bool use_mtp_head = batch_inp.use_mtp_head;
+    // const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup;
 
     do {
         const auto & ubatch = mctx->get_ubatch();

@@ -1169,13 +1169,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
            for (uint32_t i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) {
                pos_str += std::to_string(ubatch.pos[i]) + " ";
            }
-           LLAMA_LOG_WARN(
-               "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n",
-               ubatch.n_tokens,
-               batch_inp.update_mtp_kv ? "true" : "false",
-               batch_inp.use_mtp_head ? "true" : "false",
-               pos_str.c_str()
-           );
+           // LLAMA_LOG_WARN(
+           //     "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n",
+           //     ubatch.n_tokens,
+           //     batch_inp.update_mtp_kv ? "true" : "false",
+           //     batch_inp.use_mtp_head ? "true" : "false",
+           //     pos_str.c_str()
+           // );
        }
 
        // count the outputs in this ubatch

@@ -1193,16 +1193,16 @@ int llama_context::decode(const llama_batch & batch_inp) {
            // needs to happen before the graph is built
            n_outputs = n_outputs_new;
        }
-       if (do_mtp_kv_update) {
-           LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens);
-           std::string positions_str;
-           for (int i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) {
-               positions_str += std::to_string(ubatch.pos[i]) + " ";
-           }
-           LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str());
-       }
+       // if (do_mtp_kv_update) {
+       //     LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens);
+       //     std::string positions_str;
+       //     for (int i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) {
+       //         positions_str += std::to_string(ubatch.pos[i]) + " ";
+       //     }
+       //     LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str());
+       // }
        ggml_status status;
-       const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head, is_prompt_warmup);
+       const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, batch_inp.mtp_params);
        if (!res) {
            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
            llama_pos pos_min[LLAMA_MAX_SEQ];

@@ -1261,17 +1261,17 @@ int llama_context::decode(const llama_batch & batch_inp) {
            }
        }
 
-       if (use_mtp_head) {
-           if (t_embd != nullptr) {
-               LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! This will cause corruption.\n");
-           } else {
-               LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n");
-           }
-       }
+       // if (use_mtp_head) {
+       //     if (t_embd != nullptr) {
+       //         LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! This will cause corruption.\n");
+       //     } else {
+       //         LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n");
+       //     }
+       // }
 
        // extract embeddings
        if (t_embd && n_outputs > 0) {
-           if (!use_mtp_head) {
+           if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
                ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
                GGML_ASSERT(backend_embd != nullptr);
 

@@ -1389,7 +1389,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
        ggml_backend_sched_reset(sched.get());
    }
 
-   if (!use_mtp_head) {
+   if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
        synchronize();
        const size_t n_embd = this->model.hparams.n_embd;
        double full_buffer_sum = calculate_vector_sum(this->embd, n_outputs_all * n_embd);

@@ -1534,7 +1534,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 
    auto * res = gf_res_reserve.get();
 
-   const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false, false);
+   const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE });
 
    res->reset();
 

@@ -1556,8 +1556,7 @@ llm_graph_params llama_context::graph_params(
        const llama_ubatch & ubatch,
        const llama_memory_context_i * mctx,
        llm_graph_type gtype,
-       bool update_mtp_kv,
-       bool use_mtp_head) const {
+       const llama_mtp_params & mtp_params) const {
    return {
        /*.arch       =*/ model.arch,
        /*.hparams    =*/ model.hparams,

@@ -1570,8 +1569,7 @@ llm_graph_params llama_context::graph_params(
        /*.loras      =*/ &loras,
        /*.mctx       =*/ mctx,
        /*.cross      =*/ &cross,
-       /*.update_mtp_kv =*/ update_mtp_kv,
-       /*.use_mtp_head  =*/ use_mtp_head,
+       /*.mtp_params =*/ mtp_params,
        /*.n_outputs  =*/ n_outputs,
        /*.cb         =*/ graph_get_cb(),
        /*.res        =*/ res,

@@ -2312,7 +2310,7 @@ void llama_context::opt_epoch_iter(
 
    auto * res = gf_res_prev.get();
 
-   const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false, false);
+   const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE });
 
    res->reset();
 
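The op_type checks in process_ubatch are still spread over several spots. One way they could be consolidated, sketched here as a possible follow-up rather than part of this commit (embd and draft_input_hidden_state stand for the llama_context members used above):

    // Possible follow-up (not in this commit): a single dispatch that picks the
    // hidden-state source fed to the MTP graph for a given operation.
    static const float * mtp_hidden_state_source(
            llama_mtp_op_type op,
            const float * embd,                       // main-model output embeddings
            const float * draft_input_hidden_state) { // stashed hidden state for drafting
        switch (op) {
            case MTP_OP_WARMUP:
            case MTP_OP_UPDATE_ACCEPTED: return embd;
            case MTP_OP_DRAFT_GEN:       return draft_input_hidden_state;
            case MTP_OP_NONE:            break;
        }
        return nullptr;                               // no MTP work for this batch
    }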

src/llama-context.h

Lines changed: 2 additions & 5 deletions

@@ -111,9 +111,7 @@ struct llama_context {
            llm_graph_type gtype,
            llama_memory_context_i * mctx,
            ggml_status & ret,
-           const bool do_mtp_kv_update,
-           const bool use_mtp_head,
-           bool is_mtp_prompt_warmup);
+           const llama_mtp_params & mtp_params);
 
    int encode(const llama_batch & batch_inp);
    int decode(const llama_batch & batch_inp);

@@ -229,8 +227,7 @@ struct llama_context {
            const llama_ubatch & ubatch,
            const llama_memory_context_i * mctx,
            llm_graph_type gtype,
-           bool update_mtp_kv,
-           bool use_mtp_head) const;
+           const llama_mtp_params & mtp_params) const;
 
    llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;
 
src/llama-graph.h

Lines changed: 2 additions & 4 deletions

@@ -417,8 +417,7 @@ struct llm_graph_params {
    const llama_adapter_loras * loras;
    const llama_memory_context_i * mctx;
    const llama_cross * cross;
-   bool update_mtp_kv;
-   bool use_mtp_head;
+   llama_mtp_params mtp_params;
 
    uint32_t n_outputs;
 

@@ -467,8 +466,7 @@ struct llm_graph_params {
            cvec  == other.cvec &&
            loras == other.loras &&
            cross == other.cross &&
-           update_mtp_kv == other.update_mtp_kv &&
-           use_mtp_head == other.use_mtp_head &&
+           mtp_params.op_type == other.mtp_params.op_type &&
            n_outputs == other.n_outputs;
    }
};
