diff --git a/common/speculative.cpp b/common/speculative.cpp index 77ed75913d5c7..a7a4042682184 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -373,48 +373,85 @@ llama_token mtp_speculative_gen_draft( if (!smpl) { return -1; } + llama_batch mtp_batch = llama_batch_init(1, 0, 1); + const llama_pos draft_pos = n_past; + const llama_seq_id draft_seq_id = 0; + common_batch_add(mtp_batch, id_last, n_past, {0}, true); - llama_batch batch = llama_batch_init(1, 0, 1); - common_batch_add(batch, id_last, n_past, {0}, true); + mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN; - llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx); + // Perform the MTP draft generation decode. This writes the MTP layer's + // KV state for the draft token into the cache. + llama_decode(ctx, mtp_batch); + llama_batch_free(mtp_batch); + + // CRITICAL: Purge the metadata for the draft token we just wrote. + // This makes the physical cell available again for the main model's validation pass, + // preventing a cache state corruption where two cells map to the same logical position. + llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1); const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_n_vocab(vocab); - llama_token_data_array * cur_p = common_sampler_get_candidates(smpl); - cur_p->size = n_vocab; for (int i = 0; i < n_vocab; ++i) { cur_p->data[i].id = i; - cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i]; + cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0. } cur_p->sorted = false; - common_sampler_apply_chain(smpl, cur_p); + + return cur_p->data[0].id; +} - const llama_token id = cur_p->data[0].id; - llama_batch_free(batch); +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) { + if (batch.n_tokens == 0) { + return; + } - return id; -} + LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens); + llama_batch mtp_batch = batch; + if (is_prompt_warmup) { + mtp_batch.mtp_params.op_type = MTP_OP_WARMUP; + } else { + mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED; + } -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start, size_t n_tokens) { - mtp_kv_update_data token; + for (int i = 0; i < mtp_batch.n_tokens; ++i) { + mtp_batch.logits[i] = true; + } + llama_decode(ctx, mtp_batch); +} - if (n_tokens < 0) { - n_tokens = tokens.size(); +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +) { + if (ids.empty()) { + return; } - for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) { - token = tokens[i]; - //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start)); + // Prepare a resized copy of the validation sinfo to match the number of accepted tokens. + // This sets up the context for a "forced sinfo" decode. + if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) { + return; + } - mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start); + // Build a new batch containing only the accepted tokens. 
+ llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1); + for (size_t i = 0; i < ids.size(); ++i) { + common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true); } - tokens.clear(); -} \ No newline at end of file + mtp_update_kv_cache(ctx, accepted_batch, false); + + // Clean up the forced state to not affect subsequent, normal decode calls. + llama_mtp_cancel_sinfo_update(ctx); + + llama_batch_free(accepted_batch); +} diff --git a/common/speculative.h b/common/speculative.h index 230f8382bccfc..8b81f4ac77df3 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -49,4 +49,11 @@ llama_tokens common_speculative_gen_draft( const llama_tokens & prompt, llama_token id_last); -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start = 0, size_t n_tokens = -1); +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup); + +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +); diff --git a/include/llama.h b/include/llama.h index e43cd83468d0f..0b15d4bf1cd0d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -221,6 +221,17 @@ extern "C" { // - if not: only the last token is output // ) // + typedef enum { + MTP_OP_NONE, + MTP_OP_WARMUP, + MTP_OP_UPDATE_ACCEPTED, + MTP_OP_DRAFT_GEN, + } llama_mtp_op_type; + + typedef struct llama_mtp_params { + llama_mtp_op_type op_type; + } llama_mtp_params; + typedef struct llama_batch { int32_t n_tokens; @@ -230,6 +241,7 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" + llama_mtp_params mtp_params; } llama_batch; enum llama_model_kv_override_type { @@ -1454,8 +1466,37 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); - LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx, - const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); + // + // MTP + // + + LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); + + /** + * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo. + * This is used after speculative validation when only a subset of draft tokens are accepted. + * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized. + * @return true on success. + */ + LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted); + + /** + * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode. + * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned. + * @return true on success. + */ + LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx); + + /** + * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo. + */ + LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx); + + /** + * @brief Removes KV cache metadata for a specified sequence and token range. + * This makes the physical cells logically available again without deleting the tensor data. 
+ */ + LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); #ifdef __cplusplus } diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index ff73429301d68..c01960c55ea94 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -834,13 +834,14 @@ struct llama_batch llama_batch_get_one( struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { llama_batch batch = { - /*n_tokens =*/ 0, - /*tokens =*/ nullptr, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*n_seq_id =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*.mtp_params =*/ { MTP_OP_NONE }, }; if (embd) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index fb285a8d297c9..fb35d6c79debf 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -17,6 +17,11 @@ // // llama_context // +struct llama_context_kv_cache_data { + llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos; + llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force; + const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr; +}; llama_context::llama_context( const llama_model & model, @@ -105,6 +110,8 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + kv_cache_data = new llama_context_kv_cache_data(); + { const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows; @@ -370,6 +377,7 @@ llama_context::llama_context( llama_context::~llama_context() { ggml_opt_free(opt_ctx); + delete static_cast(kv_cache_data); } void llama_context::synchronize() { @@ -729,7 +737,8 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { +llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, + const llama_mtp_params & mtp_params) { if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -741,7 +750,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype); + const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); @@ -772,6 +781,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } } + if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation + if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) { + ret = GGML_STATUS_FAILED; + return nullptr; + } + } + // set the input data for the input tensors { //const auto t_start_us = ggml_time_us(); @@ -789,7 +805,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } ret = GGML_STATUS_SUCCESS; - + if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + ggml_tensor * sum_tensor = 
ggml_get_tensor(res->get_ctx(), "mtp_input_sum"); + } return res; } @@ -850,7 +868,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, { MTP_OP_NONE }); cparams.causal_attn = causal_attn_org; @@ -964,6 +982,8 @@ int llama_context::encode(const llama_batch & batch_inp) { int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + auto * kvd = static_cast(kv_cache_data); + if (!memory) { LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); return encode(batch_inp); @@ -1018,10 +1038,11 @@ int llama_context::decode(const llama_batch & batch_inp) { // handle any pending defrags/shifts kv_self_update(false); - llama_memory_context_ptr mctx; + std::unique_ptr mctx; while (true) { - mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + mctx = this->initialize_decode_context(batch_inp, output_all); + if (!mctx) { return -2; } @@ -1033,29 +1054,28 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_MEMORY_STATUS_NO_UPDATE: { LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); - return -2; } case LLAMA_MEMORY_STATUS_FAILED_PREPARE: { + if (kvd->forced_sinfos) { + LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__); + return -1; + } + if (!did_optimize) { did_optimize = true; - if (kv_self_update(true)) { LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); - continue; } } - LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens()); - return 1; } case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: { LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens()); - return -2; } } @@ -1070,10 +1090,9 @@ int llama_context::decode(const llama_batch & batch_inp) { }; int64_t n_outputs_prev = 0; - + do { const auto & ubatch = mctx->get_ubatch(); - // count the outputs in this ubatch { int32_t n_outputs_new = 0; @@ -1089,10 +1108,8 @@ int llama_context::decode(const llama_batch & batch_inp) { // needs to happen before the graph is built n_outputs = n_outputs_new; } - ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); - + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, batch_inp.mtp_params); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache llama_pos pos_min[LLAMA_MAX_SEQ]; @@ -1131,7 +1148,6 @@ int llama_context::decode(const llama_batch & batch_inp) { auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? 
res->get_embd() : nullptr; - embd_tensor = res->get_embd(); if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1139,71 +1155,81 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract logits if (t_logits && n_outputs > 0) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); - - float * logits_out = logits + n_outputs_prev*n_vocab; - - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + // MTP operations that are purely for updating the KV cache + // (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor + // as a side effect of running the graph. If these logits are copied + // back to the main context buffer, they will overwrite the valid logits + // produced by the main model's pass, leading to incorrect sampling. + // This condition explicitly prevents that copy for cache-only operations. + if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP && + batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + float * logits_out = logits + n_outputs_prev*n_vocab; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } } } // extract embeddings if (t_embd && n_outputs > 0) { - ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); - GGML_ASSERT(backend_embd != nullptr); - - switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; - - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings (cleared before processing each batch) - auto & embd_seq_out = embd_seq; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // extract the rerank score - n_cls_out floats per sequence - auto & embd_seq_out = embd_seq; - - const uint32_t n_cls_out = hparams.n_cls_out; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_cls_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + if 
(batch_inp.mtp_params.op_type == MTP_OP_NONE) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + const uint32_t n_cls_out = hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); - } + } } } @@ -1268,7 +1294,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // overlap with device computation. 
ggml_backend_sched_reset(sched.get()); } - return 0; } @@ -1408,7 +1433,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); @@ -1428,8 +1453,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, - const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + const llama_memory_context_i * mctx, + llm_graph_type gtype, + const llama_mtp_params & mtp_params) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1442,36 +1468,13 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.mtp_params =*/ mtp_params, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, }; } -llm_graph_params llama_context::mtp_graph_params( - llm_graph_result * res, - const llama_ubatch& ubatch, - const llama_memory_context_i * mctx) { - size_t n_nodes = std::max(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer)); - ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes); - return { - /*.arch =*/ model.arch, - /*.hparams =*/ model.hparams, - /*.cparams =*/ cparams, - /*.ubatch =*/ ubatch, - /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, - /*.sched =*/ temp_sched, - /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ &loras, - /*.mctx =*/ mctx, - /*.cross =*/ &cross, - /*.n_outputs =*/ 1, - /*.cb =*/ graph_get_cb(temp_sched), - /*.res =*/ res, - }; -} - std::unique_ptr llama_context::mtp_memory_batch(const llama_batch& batch_inp) { const auto& vocab = model.vocab; const auto& hparams = model.hparams; @@ -2206,7 +2209,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); @@ -2995,79 +2998,121 @@ void llama_opt_epoch( callback_eval); } -void llama_build_and_execute_mtp_graph(struct llama_context * ctx, - const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { +void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) { + ctx->draft_input_hidden_state = hidden_state; +} - const auto * model = llama_get_model(ctx); +bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) { + auto * kvd = static_cast(ctx->kv_cache_data); + const auto & last_sinfo = kvd->last_main_model_sinfos; - auto res_mtp = std::make_unique(ctx->graph_max_nodes()); - std::unique_ptr mctx = ctx->mtp_memory_batch(batch_inp); + if (last_sinfo.empty()) { + LLAMA_LOG_ERROR("%s: The main call sinfo is not available for warmup.\n", __func__); + return false; + } - std::vector idxs; - idxs.push_back(n_past); - llama_kv_cache_unified::slot_info sinfo = { - /*.s0 =*/ 0, - /*.s1 =*/ 0, - /*.strm =*/ { 0 }, - /*.idxs =*/ { idxs }, - }; - llama_kv_cache_unified::slot_info_vec_t sinfos; - sinfos.push_back(sinfo); + kvd->forced_sinfos = &last_sinfo; + return true; +} - static_cast(mctx.get())->set_sinfos(sinfos); - const auto& ubatch_mtp = mctx->get_ubatch(); - //llama_ubatch ubatch_mtp; - 
//ubatch_mtp.n_tokens = 1; - //ubatch_mtp.pos = &n_past; +bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) { + auto * kvd = static_cast(ctx->kv_cache_data); + const auto & last_sinfo = kvd->last_main_model_sinfos; + + if (last_sinfo.empty() || last_sinfo[0].idxs.empty()) { + LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.", __func__); + return false; + } - auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get())); - ggml_backend_sched_t sched = params_mtp->sched; + kvd->resized_sinfo_for_force = last_sinfo; + + kvd->resized_sinfo_for_force[0].idxs[0].resize(n_accepted); - auto * last_embd = ctx->get_embeddings_ith(last_tok_idx); + kvd->forced_sinfos = &kvd->resized_sinfo_for_force; - //if (mctx && !mctx->set_n_kv()) { - // LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); - //} - static_cast(mctx.get())->set_n_kv(); + return true; +} - auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past); +void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) { + auto * kvd = static_cast(ctx->kv_cache_data); + kvd->forced_sinfos = nullptr; +} - if (!gf) { - LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__); - if (sched) ggml_backend_sched_free(sched); - return; +void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (memory) { + static_cast(memory.get())->seq_rm(seq_id, p0, p1); } +} + +void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + ctx->kv_cache_seq_rm(seq_id, p0, p1); +} + +/* + Initializes the memory context for a decode operation. + The logic follows a specific priority: + 1. Warmup: Always use a standard batch initialization. + 2. Forced S-Info (MTP Updates): If a specific KV cache layout is forced, use it. + 3. Default: Use a standard batch initialization, and if it's a main model pass, + save the resulting s-info for potential future reuse by MTP. 
+*/ +std::unique_ptr llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) { + auto * kvd = static_cast(kv_cache_data); + std::unique_ptr mctx; - ggml_backend_sched_reset(sched); // clear the allocation of the previous graph - ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it + if (cparams.warmup) { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) { + LLAMA_LOG_DEBUG("%s: Forcing sinfos, bypassing find_slot.\n", __func__); + mctx = static_cast(memory.get())->init_batch_with_sinfos( + *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true + ); + } else { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); - ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input"); - ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors + if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { + if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) { + kvd->last_main_model_sinfos = static_cast(mctx.get())->get_sinfos(); + } else { + kvd->last_main_model_sinfos.clear(); + } + } + } + + return mctx; +} - ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input"); - ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors - ggml_backend_sched_graph_compute(sched, gf); // execute the graph +bool llama_context::prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params) { + + const char * target_tensor_name = "result_embd_pooled"; + ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); - struct ggml_tensor * logits_mtp = res_mtp->get_logits(); + const float * source_hidden_state = nullptr; + if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + source_hidden_state = this->embd; + } else { // MTP_OP_DRAFT_GEN + source_hidden_state = this->draft_input_hidden_state; + } - if (logits_mtp) { - float * logits_dest = ctx->get_logits_ith(last_tok_idx); - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp); - if (backend_res) { - // ggml_backend_tensor_get is the function for GPU->CPU copies. - // We are copying a single 32-bit integer. 
- ggml_backend_tensor_get(logits_mtp, - logits_dest, // Pointer to our C++ variable - 0, // Starting offset in bytes - ggml_nbytes(logits_mtp)); // Number of bytes to copy - } else { - LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__); + if (source_hidden_state != nullptr && hidden_states_input != nullptr) { + const char * op_type; + if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + op_type = "MTP_UPDATE"; + } else { // MTP_OP_DRAFT_GEN + op_type = "DRAFT_GEN"; } + + ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); } else { - LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__); + LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n", + __func__, target_tensor_name); + return false; } - ggml_backend_sched_free(sched); -} \ No newline at end of file + return true; +} diff --git a/src/llama-context.h b/src/llama-context.h index e8ea3a4c9be39..4d77d5d81aef1 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -20,6 +20,8 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; +struct llama_context_kv_cache_data; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -27,6 +29,15 @@ struct llama_context { llama_context_params params); ~llama_context(); + + // The llama_context manages significant resources (GPU memory, file handles, PImpl data) + // and is fundamentally a non-copyable, non-movable object. Deleting these special + // member functions enforces this rule and is also technically required to allow the + // PImpl pattern (via unique_ptr or void*) with an incomplete type in the header. 
+ llama_context(const llama_context &) = delete; + llama_context & operator=(const llama_context &) = delete; + llama_context(llama_context &&) = delete; + llama_context & operator=(llama_context &&) = delete; void synchronize(); @@ -61,6 +72,8 @@ struct llama_context { float * get_embeddings_seq(llama_seq_id seq_id); ggml_tensor * get_embeddings_tensor(); + const float * draft_input_hidden_state = nullptr; + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -91,6 +104,8 @@ struct llama_context { int32_t il_start, int32_t il_end); + void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1); + // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation @@ -99,7 +114,8 @@ struct llama_context { const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, - ggml_status & ret); + ggml_status & ret, + const llama_mtp_params & mtp_params); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -200,23 +216,33 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); - llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx); - void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i); ggml_backend_sched_t create_temp_scheduler(size_t n_nodes); std::unique_ptr mtp_memory_batch(const llama_batch& batch_inp); + // For MTP KV cache cell reuse + void * kv_cache_data; + private: llm_graph_params graph_params( llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const; + llm_graph_type gtype, + const llama_mtp_params & mtp_params) const; llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; + // Methods for MTP decode + std::unique_ptr initialize_decode_context(const llama_batch & batch_inp, const bool output_all); + + bool prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params); + // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 053c72d6dc8d1..be7de40454e80 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1074,6 +1074,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { return cur; } + +ggml_tensor * llm_graph_context::build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const { + auto inp = std::make_unique(); + ggml_tensor * cur = nullptr; + + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_name(inp->tokens, "mtp_inp_tokens"); + ggml_set_input(inp->tokens); + + cur = ggml_get_rows(ctx0, mtp_tok_embd, inp->tokens); + } else { + GGML_ABORT("fatal error: MTP update expects token IDs, not embeddings"); + } + + cb(cur, "mtp_inp_embd", -1); + res->add_input(std::move(inp)); + return cur; +} + ggml_tensor * llm_graph_context::build_inp_pos() const { auto inp = std::make_unique(hparams.n_pos_per_embd()); diff --git a/src/llama-graph.h b/src/llama-graph.h index 10702ed219c01..3c5feadfdc733 100644 --- a/src/llama-graph.h 
+++ b/src/llama-graph.h @@ -29,6 +29,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DRAFT, }; enum llm_ffn_op_type { @@ -94,6 +95,20 @@ class llm_graph_input_i { using llm_graph_input_ptr = std::unique_ptr; +class llm_graph_input_mtp_states : public llm_graph_input_i { +public: + llm_graph_input_mtp_states() = default; + virtual ~llm_graph_input_mtp_states() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override {} + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * states = nullptr; +}; + class llm_graph_input_embd : public llm_graph_input_i { public: llm_graph_input_embd() = default; @@ -402,6 +417,7 @@ struct llm_graph_params { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + llama_mtp_params mtp_params; uint32_t n_outputs; @@ -450,6 +466,7 @@ struct llm_graph_params { cvec == other.cvec && loras == other.loras && cross == other.cross && + mtp_params.op_type == other.mtp_params.op_type && n_outputs == other.n_outputs; } }; @@ -664,6 +681,7 @@ struct llm_graph_context { // ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; + ggml_tensor * build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const; ggml_tensor * build_inp_pos() const; ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 53466264cd9a7..8d9b1f631f7b6 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -508,6 +508,34 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } +llama_memory_context_ptr llama_kv_cache_unified::init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update) { + + if (sinfos.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + balloc.split_reset(); + std::vector ubatches; + while (true) { + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + if (ubatch.n_tokens == 0) { + break; + } + ubatches.push_back(std::move(ubatch)); + } + + if (ubatches.size() != sinfos.size()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, sinfos, std::move(ubatches), is_inplace_update); +} + llama_memory_context_ptr llama_kv_cache_unified::init_full() { return std::make_unique(this); } @@ -928,64 +956,81 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ } assert(res.s1 >= res.s0); + if (!res.empty()) { + std::string idxs_str; + for (const auto& vec : res.idxs) { + if (!vec.empty()) { + if (vec.size() > 8) { + idxs_str += " [" + std::to_string(vec.front()) + "..." + std::to_string(vec.back()) + " (" + std::to_string(vec.size()) + " cells)]"; + } else { + idxs_str += " ["; + for(size_t i = 0; i < vec.size(); ++i) { + idxs_str += std::to_string(vec[i]) + (i == vec.size() - 1 ? 
"" : ", "); + } + idxs_str += "]"; + } + } + } + } return res; } -void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { - // keep track of the max sequence position that we would overwrite with this ubatch - // for non-SWA cache, this would be always empty - llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - seq_pos_max_rm[s] = -1; - } - - assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - - for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { - for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { - const uint32_t i = s*sinfo.size() + ii; - - auto & cells = v_cells[sinfo.strm[s]]; +void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) { + // For "in-place" updates (MTP warmup/accept), we only update the tensor data. + // The cell metadata (logical position, sequence ID) has already been set + // by the main model's pass. We must skip all metadata modifications + // to prevent `pos_set` from asserting on an already-set cell. + if (!is_inplace_update) { + // keep track of the max sequence position that we would overwrite with this ubatch + // for non-SWA cache, this would be always empty + llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + seq_pos_max_rm[s] = -1; + } - const auto idx = sinfo.idxs[s][ii]; + assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const uint32_t i = s*sinfo.size() + ii; - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); + auto & cells = v_cells[sinfo.strm[s]]; - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + const auto idx = sinfo.idxs[s][ii]; - cells.rm(idx); - } + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + cells.rm(idx); + } - cells.pos_set(idx, ubatch.pos[i]); + cells.pos_set(idx, ubatch.pos[i]); - for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(idx, ubatch.seq_id[i][s]); + for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { + cells.seq_add(idx, ubatch.seq_id[i][s]); + } } } - } - // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence - // will be present in the cache. so we have to purge any position which is less than those we would overwrite - // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - if (seq_pos_max_rm[s] == -1) { - continue; - } + // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence + // will be present in the cache. 
so we have to purge any position which is less than those we would overwrite + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq_pos_max_rm[s] == -1) { + continue; + } - GGML_ASSERT(s < seq_to_stream.size()); + GGML_ASSERT(s < seq_to_stream.size()); - auto & cells = v_cells[seq_to_stream[s]]; + auto & cells = v_cells[seq_to_stream[s]]; - if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { - LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", - __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { - seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + } } } @@ -2290,7 +2335,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_kv_cache_unified::slot_info_vec_t sinfos, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) { + std::vector ubatches, + bool is_inplace_update) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)), is_inplace_update(is_inplace_update) { } llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; @@ -2315,7 +2361,7 @@ bool llama_kv_cache_unified_context::apply() { return true; } - kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); + kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur], is_inplace_update); n_kv = kv->get_n_kv(); diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index c02607c2d0f38..f64f7faa5c062 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -116,6 +116,12 @@ class llama_kv_cache_unified : public llama_memory_i { llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; + + llama_memory_context_ptr init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update); llama_memory_context_ptr init_full() override; @@ -181,7 +187,7 @@ class llama_kv_cache_unified : public llama_memory_i { slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] - void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); + void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update = false); // // input API @@ -321,7 +327,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { llama_kv_cache_unified_context( llama_kv_cache_unified * kv, slot_info_vec_t sinfos, - std::vector ubatches); + std::vector ubatches, + bool is_inplace_update = false); virtual ~llama_kv_cache_unified_context(); @@ -365,6 +372,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { void set_sinfos(slot_info_vec_t new_sinfos); + const slot_info_vec_t & get_sinfos() const { return sinfos; } + private: llama_memory_status status; @@ -399,4 +408,6 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { // a heuristic, to avoid attending the full cache if it is not yet utilized // as the cache gets filled, the benefit from this heuristic disappears int32_t n_kv; + + bool is_inplace_update = false; }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index dd4bf211b7e94..ab7daee356ae9 100644 --- 
a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13789,236 +13789,219 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); ggml_tensor * cur; - ggml_tensor * inpL; - inpL = build_inp_embd(model.tok_embd); + if (params.mtp_params.op_type != MTP_OP_NONE) { + ggml_tensor* hidden_states_from_main_model; - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv_unified(); + if (params.mtp_params.op_type == MTP_OP_WARMUP || params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); + } else { + hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - // Only process up to last layer (skip final NextN layer) - // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { - ggml_tensor * inpSA = inpL; + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); + } - // Pre-attention norm - cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head); - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); + } else { + ggml_tensor * inpL = build_inp_embd(model.tok_embd); + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_unified(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - 
cb(Vcur, "Vcur", il); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); - // Apply Q/K norm if available (GLM-4.5 355B variant) - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - } - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); - } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); - if (il == n_transformer_layers - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } - // Post-attention norm - cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "post_attn_norm", il); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); - // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) - if (static_cast(il) < hparams.n_layer_dense_lead) { - // Dense FFN layer - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, 
LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } else { - // Process routed experts using existing MoE infrastructure - ggml_tensor * routed_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, - (llama_expert_gating_func_type) hparams.expert_gating_func, - il); - cb(routed_out, "ffn_moe_out", il); + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); - // Process shared expert on original input - ggml_tensor * shared_out = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(shared_out, "ffn_shexp_out", il); + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); - // Final output: routed_output + shared_output - cur = ggml_add(ctx0, routed_out, shared_out); - cb(cur, "ffn_out", il); - } + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } - cur = ggml_add(ctx0, cur, ffn_inp); + cur = ggml_add(ctx0, cur, ffn_inp); - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); - // input for next layer - inpL = cur; - } + // input for next layer + inpL = cur; + } - cur = inpL; - cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); - res->t_embd = cur; + // cb(cur, "result_norm", -1); + res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur); + // Use the main model header + res->t_logits = build_lora_mm(model.output, cur); + } - cb(cur, "result_output", -1); - res->t_logits = cur; + ggml_build_forward_expand(gf, res->t_logits); - ggml_build_forward_expand(gf, cur); } -}; - -struct llm_build_glm4_moe_mtp : public llm_graph_context { - llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params, - // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization - llama_token 
last_token_id, int n_past - ) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); +private: + ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, + int64_t n_embd_head + ) { + ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings); - // Assuming a single MTP layer at the end const int il = hparams.n_layer - 1; - const auto & mtp_layer = model.layers[il]; + ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy); - // ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - // ggml_set_i32(inp_pos, n_past); - ggml_tensor * inp_pos = build_inp_pos(); + ggml_set_name(sum_node, "mtp_input_sum"); - //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr; + ggml_tensor * inp_pos = build_inp_pos(); auto * inp_attn = build_attn_inp_kv_unified(); + ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens); - // get MTP embedding for last (conventionally sampled) token - // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - // LLAMA_LOG_INFO("step: '%d'\n", 5641); - // ggml_set_i32(inp_token_id, last_token_id); - //ggml_set_no_alloc(ctx0, false); - //LLAMA_LOG_INFO("last token id: '%d'\n", last_token_id); - - ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - ggml_set_name(inp_token_id, "mtp_token_id_input"); - ggml_set_input(inp_token_id); - - //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id); - //ggml_set_no_alloc(ctx0, true); - - ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id); ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); - - ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd); - ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input"); - ggml_set_input(prev_embedding_leaf); - - // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states) - ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); - //token_emb_norm = ggml_cont(ctx0, token_emb_norm); - //hidden_state_norm = ggml_cont(ctx0, hidden_state_norm); - - ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat - - ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj + ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); + ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // now proceed through last layer (skipped in main model) ggml_tensor * inpSA = cur; - // Pre-attention norm for the MTP block - ggml_tensor* attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il); + cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il); // self-attention { ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur); - if (mtp_layer.bq) { - Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq); - } + if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq); cb(Qcur, "Qcur", il); ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur); - if (mtp_layer.bk) { - Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk); - } + if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur); - if (mtp_layer.bv) { - Vcur = 
ggml_add(ctx0, Vcur, mtp_layer.bv); - } + if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv); cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -14052,8 +14035,8 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - mtp_layer.wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + mtp_layer.wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); @@ -14090,12 +14073,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); - cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); - - res->t_logits = cur; - ggml_build_forward_expand(gf, res->t_logits); + + return cur; } }; @@ -18324,8 +18305,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { + std::unique_ptr llm; - switch (arch) { case LLM_ARCH_LLAMA: { @@ -18683,25 +18664,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { GGML_ABORT("fatal error"); } - // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); - - return llm->res->get_gf(); -} - -ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params, - llama_token last_token_id, int n_past) const { - std::unique_ptr llm; - - switch (arch) { - case LLM_ARCH_GLM4_MOE: - { - llm = std::make_unique(*this, params, last_token_id, n_past); - } break; - default: - GGML_ABORT("fatal error"); + if (params.mtp_params.op_type == MTP_OP_NONE) { + // add on pooling layer + llm->build_pooling(cls, cls_b, cls_out, cls_out_b); } - return llm->res->get_gf(); } diff --git a/src/llama-model.h b/src/llama-model.h index b28a37488f78a..6fcd74d57fdca 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -475,8 +475,6 @@ struct llama_model { // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; - ggml_cgraph * build_mtp_graph(const llm_graph_params & params, - llama_token last_token_id, int n_past) const; private: struct impl; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 34053cd040388..a24532c6939af 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1296,7 +1296,6 @@ struct server_slot { common_speculative * spec = nullptr; bool has_mtp = false; - std::vector mtp_kv_update_batch; int32_t last_tok_idx = -1; std::vector lora; @@ -3394,9 +3393,6 @@ struct server_context { // embedding requires all tokens in the batch to be output const bool need_embd = server_task_type_need_embd(slot.task_type); - if (slot.has_mtp) { - slot.mtp_kv_update_batch.push_back({ cur_tok, slot.n_past, batch.n_tokens }); - } common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); slot.cache_tokens.push_back(cur_tok); @@ -3513,11 +3509,19 @@ struct server_context { continue; // continue loop of n_batch } - for (auto & slot : slots) { - // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - if (slot.has_mtp) { - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens); - } + if (slot_batched && slot_batched->has_mtp && + (slot_batched->state == SLOT_STATE_PROCESSING_PROMPT || 
slot_batched->state == SLOT_STATE_DONE_PROMPT)) { + + // Prepare the context to reuse the exact sinfo layout (including multiple u-batches) + // from the main model's prompt processing pass. This ensures the MTP layer's + // KV cache is perfectly aligned. + if (llama_mtp_prepare_sinfo_for_warmup(ctx)) { + mtp_update_kv_cache(ctx, batch_view, true); + // Clean up the forced state so it does not affect subsequent decodes. + llama_mtp_cancel_sinfo_update(ctx); + } else { + LOG_ERR("%s: failed to prepare the MTP sinfo for warmup\n", __func__); + } } // move the head of the batch forward with the number of tokens we just processed @@ -3554,20 +3558,16 @@ struct server_context { } const int tok_idx = slot.i_batch - i; - + // Set the initial hidden state for the first draft generation. + if (slot.has_mtp) { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1)); + } llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.last_tok_idx = tok_idx; - //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str()); slot.i_batch = -1; - common_sampler_accept(slot.smpl, id, true); - // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - //if (slot.has_mtp) { - // mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); - //} - slot.n_decoded += 1; const int64_t t_current = ggml_time_us(); @@ -3652,11 +3652,6 @@ struct server_context { draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); } - //llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); - //llama_tokens draft; - //draft.reserve(1); - //draft.push_back(draft_id); - // ignore small drafts if (slot.params.speculative.n_min > (int)draft.size()) { SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min); @@ -3677,17 +3672,21 @@ struct server_context { } SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); - llama_decode(ctx, slot.batch_spec); // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - + if (slot.has_mtp) { + // Use the hidden state of the last accepted token as input for the next draft. + + if (!ids.empty()) { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1)); + } else { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0)); + } + + mtp_accept_tokens(ctx, ids, slot.n_past, slot.id); } slot.n_past += ids.size();
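
The following is a minimal sketch, not part of the diff, of how the prompt-warmup path added to server.cpp above composes; the free-function framing and the name `mtp_warmup_after_prompt` are illustrative, while the calls are the ones introduced by this change:

```cpp
#include "llama.h"
#include "speculative.h" // mtp_update_kv_cache

// After the main model has decoded a prompt batch view, fill the MTP layer's
// KV cache for the same tokens, reusing the exact cells of the main pass.
static void mtp_warmup_after_prompt(llama_context * ctx, const llama_batch & batch_view) {
    // Reuse the slot layout (sinfo) recorded during the main model's decode so
    // the MTP cells align one-to-one with the main model's cells.
    if (!llama_mtp_prepare_sinfo_for_warmup(ctx)) {
        return; // no sinfo from a previous main-model decode is available
    }

    // Issues a decode tagged MTP_OP_WARMUP: tensor data is written in place,
    // and the cell metadata set by the main pass is left untouched.
    mtp_update_kv_cache(ctx, batch_view, /*is_prompt_warmup=*/true);

    // Always clear the forced sinfo so subsequent normal decodes are unaffected.
    llama_mtp_cancel_sinfo_update(ctx);
}
```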
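At the llama.h level, the draft step implemented by mtp_speculative_gen_draft() reduces to a one-token batch tagged MTP_OP_DRAFT_GEN followed by a purge of the draft cell's metadata. The sketch below restates that flow in simplified form (hypothetical helper name, greedy argmax instead of the sampler's candidate array, error handling omitted); it assumes llama_set_draft_input_hidden_state() has already been called with the main model's last hidden state:

```cpp
#include "llama.h"
#include "common.h" // common_batch_add

// Draft a single token with the MTP head for position n_past, then release the
// cell's logical position so the main model's validation pass can reuse it.
static llama_token mtp_draft_one(llama_context * ctx, llama_token id_last, llama_pos n_past) {
    llama_batch batch = llama_batch_init(1, 0, 1);
    common_batch_add(batch, id_last, n_past, { 0 }, true);
    batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;

    llama_decode(ctx, batch); // runs only the MTP head and writes its KV cell
    llama_batch_free(batch);

    // Purge only the metadata; the MTP tensor data for the cell is kept.
    llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, n_past, n_past + 1);

    // Greedy pick over the MTP logits of the single-token batch (index 0).
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int     n_vocab = llama_n_vocab(vocab);
    const float * logits  = llama_get_logits_ith(ctx, 0);

    llama_token best = 0;
    for (llama_token i = 1; i < n_vocab; ++i) {
        if (logits[i] > logits[best]) {
            best = i;
        }
    }
    return best;
}
```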
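Putting the pieces together, one speculative generation step now composes roughly as sketched below. This is a simplified illustration rather than the server's exact control flow: slot bookkeeping, n_past accounting and multi-token drafts are omitted, the helper name `mtp_generation_step` is hypothetical, and the mtp_speculative_gen_draft() signature follows its call sites in this patch; the MTP-specific calls are the ones added by the diff:

```cpp
#include <vector>
#include "llama.h"
#include "common.h"
#include "sampling.h"
#include "speculative.h"

// One speculative step: draft with the MTP head, validate with the main model,
// then bring the MTP KV cache up to date for the accepted tokens.
static void mtp_generation_step(llama_context * ctx, common_sampler * smpl,
                                llama_batch & batch_spec, llama_token id_last,
                                int32_t n_past, llama_seq_id seq_id) {
    // 1. Point the MTP head at the main model's last hidden state.
    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1));

    // 2. Draft one token (MTP_OP_DRAFT_GEN decode + cell purge, see above).
    //    last_tok_idx is passed as 0: the new implementation reads logits at index 0.
    const llama_token draft_id = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, /*last_tok_idx=*/0);

    // 3. Validate {id_last, draft_id} with a normal main-model decode.
    common_batch_clear(batch_spec);
    common_batch_add(batch_spec, id_last,  n_past,     { seq_id }, true);
    common_batch_add(batch_spec, draft_id, n_past + 1, { seq_id }, true);
    llama_decode(ctx, batch_spec);

    // 4. Sample the accepted prefix and sync the MTP KV cache for it, reusing
    //    the exact cells (sinfo) written by the validation decode.
    const std::vector<llama_token> ids = common_sampler_sample_and_accept_n(smpl, ctx, { draft_id });
    if (!ids.empty()) {
        llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, (int32_t) ids.size() - 1));
        mtp_accept_tokens(ctx, ids, n_past, seq_id);
    }
}
```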