
Commit 6870f97

added proper KV cache management for MTP layers and slightly refactored
1 parent 6e9bafc commit 6870f97

11 files changed: +136 additions, −94 deletions

common/speculative.cpp

Lines changed: 44 additions & 14 deletions
@@ -370,25 +370,45 @@ llama_token mtp_speculative_gen_draft(
         int32_t n_past,
         int32_t last_tok_idx) {
 
-    const auto * model = llama_get_model(ctx);
-    auto * last_embd = llama_get_embeddings_tensor(ctx);
+    llama_token token_data[] = { id_last };
+    llama_pos pos_data[] = { n_past };
+    int32_t n_seq_id_data[] = { 1 };
+    llama_seq_id seq_id_data_internal[] = { 0 };
+    llama_seq_id * seq_id_data[] = { seq_id_data_internal };
+    int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
+
+    llama_batch batch = {
+        /*.n_tokens = */ 1,
+        /*.token    = */ token_data,
+        /*.embd     = */ nullptr,
+        /*.pos      = */ pos_data,
+        /*.n_seq_id = */ n_seq_id_data,
+        /*.seq_id   = */ seq_id_data,
+        /*.logits   = */ logits_data
+    };
+
+    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+    //LOG_INF("updating kv cache for n_past: %d\n", n_past);
 
-    GGML_ASSERT(model != nullptr);
-    GGML_ASSERT(last_embd != nullptr);
-    llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);
+    if (!smpl) {
+        return -1;
+    }
+    else {
+        common_sampler_sample(smpl, ctx, last_tok_idx, true);
+        const auto * cur_p = common_sampler_get_candidates(smpl);
 
-    common_sampler_sample(smpl, ctx, last_tok_idx, true);
+        //for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+        //    LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+        //        k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        //}
 
-    const auto * cur_p = common_sampler_get_candidates(smpl);
-    /*LOG_INF("cur_p->size: %d\n", cur_p->size);
+        const llama_token id = cur_p->data[0].id;
+        return id;
+    }
+    // LOG_INF("cur_p->size: %d\n", cur_p->size);
 
-    for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
-        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-            k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-    }*/
 
     // add drafted token for each sequence
-    const llama_token id = cur_p->data[0].id;
 
     // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
     // smpl will accept the token if it doesn't get rejected by main model later
@@ -398,5 +418,15 @@ llama_token mtp_speculative_gen_draft(
     //result.reserve(1);
    //result.push_back(id);
    //return result;
-    return id;
+}
+
+
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data> & tokens) {
+    mtp_kv_update_data token;
+    for (int i = 0; i < tokens.size(); ++i) {
+        token = tokens[i];
+        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
+    }
+
+    tokens.clear();
 }
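
Note (not part of the commit): a minimal sketch of how a decode loop might call the helper above. It assumes common/speculative.h and common/sampling.h from this tree; the wrapper name is hypothetical.

// Draft one token with the MTP head, or (with a null sampler) just refresh the MTP KV cache.
static llama_token try_mtp_draft(common_sampler * smpl, llama_context * ctx,
                                 llama_token id_last, int32_t n_past, int32_t last_tok_idx) {
    const llama_token draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, last_tok_idx);
    if (draft < 0) {
        // smpl was nullptr: the call only updated the MTP layers' KV cache
        return -1;
    }
    return draft; // candidate token to verify against the main model
}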

common/speculative.h

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,12 @@ struct common_speculative_params {
     float p_min = 0.75f; // min probability required to accept a token in the draft
 };
 
+struct mtp_kv_update_data {
+    llama_token id;
+    int32_t n_past;
+    int32_t tok_idx;
+};
+
 struct common_speculative * common_speculative_init(
         struct llama_context * ctx_tgt,
         struct llama_context * ctx_dft
@@ -42,3 +48,5 @@ llama_tokens common_speculative_gen_draft(
         struct common_speculative_params params,
         const llama_tokens & prompt,
         llama_token id_last);
+
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data> & tokens);
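
Note (not part of the commit): a sketch of the intended use of the two declarations above, per the commit message — keeping the MTP layers' KV cache in sync with tokens the main model has already accepted. The function name and the pending-queue are assumptions.

#include <vector>
#include "speculative.h" // common/speculative.h in this tree

// Queue a token the main model has accepted, then replay the queue through the MTP
// head so its KV cache stays consistent; mtp_update_kv_cache() clears the vector.
static void mtp_sync_accepted(llama_context * ctx, std::vector<mtp_kv_update_data> & pending,
                              llama_token id, int32_t n_past, int32_t tok_idx) {
    pending.push_back({ id, n_past, tok_idx });
    mtp_update_kv_cache(ctx, pending); // each entry runs a cache-only draft pass (smpl == nullptr)
}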

include/llama.h

Lines changed: 1 addition & 14 deletions
@@ -544,9 +544,6 @@ extern "C" {
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
 
-    LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params,
-        struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past);
-
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -999,8 +996,6 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
-    LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx);
-
     //
     // Vocab
     //
@@ -1459,16 +1454,8 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch);
-
-    LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
-
     LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
-
-    LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
-
-
+        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
 
 #ifdef __cplusplus
 }

src/llama-batch.cpp

Lines changed: 4 additions & 2 deletions
@@ -275,7 +275,9 @@ bool llama_batch_allocr::init(
                 }
             }
 
-            if (!ok) {
+            // TEMPORARILY DISABLING THIS SANITY CHECK
+            // TODO: UNDO THIS IF IT WORKS
+            /*if (!ok) {
                 LLAMA_LOG_ERROR(
                         "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                         " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
@@ -284,7 +286,7 @@ bool llama_batch_allocr::init(
                         __func__, s, s, p0, s, seq_pos_min(s));
 
                 return false;
-            }
+            }*/
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {

src/llama-context.cpp

Lines changed: 38 additions & 28 deletions
@@ -1448,8 +1448,9 @@ llm_graph_params llama_context::graph_params(
 }
 
 llm_graph_params llama_context::mtp_graph_params(
-        llm_graph_result* res,
-        const llama_ubatch& ubatch) {
+        llm_graph_result * res,
+        const llama_ubatch& ubatch,
+        const llama_memory_context_i * mctx) {
     size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
     ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
     return {
@@ -1462,14 +1463,29 @@ llm_graph_params llama_context::mtp_graph_params(
         /*.backend_cpu =*/ backend_cpu,
         /*.cvec        =*/ &cvec,
         /*.loras       =*/ &loras,
-        /*.mctx        =*/ memory->init_batch(*balloc, 1, false).get(),
+        /*.mctx        =*/ mctx,
         /*.cross       =*/ &cross,
         /*.n_outputs   =*/ 1,
         /*.cb          =*/ graph_get_cb(temp_sched),
         /*.res         =*/ res,
     };
 }
 
+std::unique_ptr<llama_memory_context_i> llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
+    const auto& vocab = model.vocab;
+    const auto& hparams = model.hparams;
+
+    const int64_t n_vocab = vocab.n_tokens();
+    const int64_t n_embd = hparams.n_embd;
+
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, false)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return nullptr;
+    }
+
+    return memory->init_batch(*balloc, 1, false);
+}
+
 ggml_status llama_context::graph_compute(
             ggml_cgraph * gf,
             bool batched) {
@@ -2481,13 +2497,6 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
-ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
-    ctx->synchronize();
-
-    return ctx->get_embeddings_tensor();
-}
-
-
 // llama adapter API
 
 int32_t llama_set_adapter_lora(
@@ -2985,42 +2994,43 @@ void llama_opt_epoch(
             callback_eval);
 }
 
-llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) {
-    return ctx->mtp_graph_params(res, ubatch);
-}
-
-
-ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
-    return ctx->graph_compute(gf, batched);
-}
-
 void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
+        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
 
     const auto * model = llama_get_model(ctx);
 
     auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
+    llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);
+    const auto& ubatch_mtp = mctx->get_ubatch();
 
-    llama_ubatch ubatch_mtp;
-    ubatch_mtp.n_tokens = 1;
-    ubatch_mtp.pos = &n_past;
+    //llama_ubatch ubatch_mtp;
+    //ubatch_mtp.n_tokens = 1;
+    //ubatch_mtp.pos = &n_past;
 
-    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp));
+    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
+    ggml_backend_sched_t sched = params_mtp->sched;
 
-    auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past);
+    auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
 
-    ggml_backend_sched_t sched = params_mtp->sched;
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+    }
+
+    auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
 
     ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
     ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
 
     ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
-
     ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
+
+    ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
+    ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
+
     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
     struct ggml_tensor * logits_mtp = res_mtp->get_logits();;
-    LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
+    //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
 
     if (logits_mtp) {
        ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);

src/llama-context.h

Lines changed: 3 additions & 1 deletion
@@ -200,12 +200,14 @@ struct llama_context {
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
-    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch);
+    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx);
 
     void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
 
     ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
 
+    std::unique_ptr<llama_memory_context_i> mtp_memory_batch(const llama_batch& batch_inp);
+
 private:
     llm_graph_params graph_params(
                 llm_graph_result * res,

src/llama-graph.cpp

Lines changed: 0 additions & 4 deletions
@@ -1911,7 +1911,3 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
 
     return relative_bucket;
 }
-
-ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) {
-    return res->get_logits();
-}

src/llama-kv-cache-unified.cpp

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     }
     if (model.arch == LLM_ARCH_GLM4_MOE) {
         // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
+        n_layer_cache = hparams.n_layer;// - hparams.nextn_predict_layers;
     }
 
     // create a context for each buffer type

src/llama-model.cpp

Lines changed: 8 additions & 14 deletions
@@ -13948,7 +13948,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
 struct llm_build_glm4_moe_mtp : public llm_graph_context {
     llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
         // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
-        ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
+        llama_token last_token_id, int n_past
     ) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13961,7 +13961,8 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         // ggml_set_i32(inp_pos, n_past);
         ggml_tensor * inp_pos = build_inp_pos();
 
-        llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
+        //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         ggml_tensor * cur;
 
@@ -13982,9 +13983,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
         ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
 
-        ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp);
-        ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf");
-        ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf);
+        ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd);
+        ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input");
+        ggml_set_input(prev_embedding_leaf);
 
         // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states)
         ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
@@ -18693,13 +18694,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
 }
 
 ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
-        ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const {
+        llama_token last_token_id, int n_past) const {
     std::unique_ptr<llm_graph_context> llm;
 
     switch (arch) {
         case LLM_ARCH_GLM4_MOE:
             {
-                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, hidden_state_inp, last_token_id, n_past);
+                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, last_token_id, n_past);
             } break;
         default:
             GGML_ABORT("fatal error");
@@ -19024,10 +19025,3 @@ const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_te
     return model->tensors_by_name;
 }
 
-ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params,
-    ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) {
-
-    return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past);
-}
-
-

src/llama-model.h

Lines changed: 1 addition & 1 deletion
@@ -476,7 +476,7 @@ struct llama_model {
     // TODO: move this to new llm_arch_model_i interface
     ggml_cgraph * build_graph(const llm_graph_params & params) const;
     ggml_cgraph * build_mtp_graph(const llm_graph_params & params,
-        ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
+        llama_token last_token_id, int n_past) const;
 
 private:
     struct impl;

0 commit comments