diff --git a/common/speculative.cpp b/common/speculative.cpp index a7a4042682184..fe63bb30eb73a 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -363,33 +363,37 @@ llama_tokens common_speculative_gen_draft( } -llama_token mtp_speculative_gen_draft( +llama_token mtp_update_and_draft( struct common_sampler* smpl, struct llama_context* ctx, - llama_token id_last, - int32_t n_past, - int32_t last_tok_idx) { + const llama_tokens & ids_accepted, + int32_t n_past_base, + llama_token id_base +) { + if (!smpl) { return -1; } - if (!smpl) { - return -1; - } - llama_batch mtp_batch = llama_batch_init(1, 0, 1); - const llama_pos draft_pos = n_past; - const llama_seq_id draft_seq_id = 0; - common_batch_add(mtp_batch, id_last, n_past, {0}, true); + const int n_update = ids_accepted.size(); + const int n_total = n_update + 1; - mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN; + const int32_t n_past_update_start = n_past_base - n_update; + + llama_batch batch = llama_batch_init(n_update + 1, 0, 1); + for (int i = 0; i < n_update; ++i) { + common_batch_add(batch, ids_accepted[i], n_past_update_start + i, { 0 }, true); + } + common_batch_add(batch, id_base, n_past_base, { 0 }, true); + batch.mtp_params.op_type = MTP_OP_UPDATE_AND_DRAFT; - // Perform the MTP draft generation decode. This writes the MTP layer's - // KV state for the draft token into the cache. - llama_decode(ctx, mtp_batch); - llama_batch_free(mtp_batch); + if (!llama_mtp_prepare_sinfo_for_update_and_draft(ctx, n_update)) { + LOG_ERR("[MTP-FLOW] Failed to prepare hybrid sinfo. Aborting.\n"); + llama_batch_free(batch); + return -1; + } - // CRITICAL: Purge the metadata for the draft token we just wrote. - // This makes the physical cell available again for the main model's validation pass, - // preventing a cache state corruption where two cells map to the same logical position. - llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1); + llama_decode(ctx, batch); + llama_mtp_cancel_sinfo_update(ctx); + const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_n_vocab(vocab); @@ -397,12 +401,22 @@ llama_token mtp_speculative_gen_draft( cur_p->size = n_vocab; for (int i = 0; i < n_vocab; ++i) { cur_p->data[i].id = i; - cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0. + cur_p->data[i].logit = llama_get_logits_ith(ctx, n_total - 1)[i]; } cur_p->sorted = false; common_sampler_apply_chain(smpl, cur_p); - return cur_p->data[0].id; + const llama_token draft_id = cur_p->data[0].id; + llama_batch_free(batch); + + // CRITICAL: Purge the metadata for the draft token we just wrote. + // This makes the physical cell available again for the main model's validation pass, + // preventing a cache state corruption where two cells map to the same logical position. 
+    const llama_pos draft_pos = n_past_base;
+    LOG_INF("tokens being deleted: %d and %d\n", draft_pos, draft_pos + 1);
+    llama_kv_cache_seq_rm(ctx, 0, draft_pos, draft_pos + 1);
+
+    return draft_id;
 }
 
@@ -423,35 +437,8 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
     for (int i = 0; i < mtp_batch.n_tokens; ++i) {
         mtp_batch.logits[i] = true;
     }
+    const int64_t t_start_us = ggml_time_us();
     llama_decode(ctx, mtp_batch);
-}
-
-void mtp_accept_tokens(
-    struct llama_context * ctx,
-    const std::vector<llama_token> & ids,
-    int32_t n_past_base,
-    llama_seq_id seq_id
-) {
-    if (ids.empty()) {
-        return;
-    }
-
-    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
-    // This sets up the context for a "forced sinfo" decode.
-    if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
-        return;
-    }
-
-    // Build a new batch containing only the accepted tokens.
-    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
-    for (size_t i = 0; i < ids.size(); ++i) {
-        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
-    }
-
-    mtp_update_kv_cache(ctx, accepted_batch, false);
-
-    // Clean up the forced state to not affect subsequent, normal decode calls.
-    llama_mtp_cancel_sinfo_update(ctx);
-
-    llama_batch_free(accepted_batch);
-}
+    const int64_t t_end_us = ggml_time_us();
+    LOG_INF("[PERF-MTP] mtp_update_kv_cache internal decode (op=%d): %.2f ms\n", (int)mtp_batch.mtp_params.op_type, (t_end_us - t_start_us) / 1000.0);
+}
\ No newline at end of file
diff --git a/common/speculative.h b/common/speculative.h
index 8b81f4ac77df3..93f079ee85c7c 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -35,12 +35,13 @@ void common_speculative_add_replacement_tgt_dft(
 
 // sample up to n_draft tokens and add them to the batch using the draft model
-llama_token mtp_speculative_gen_draft(
+llama_token mtp_update_and_draft(
     struct common_sampler* smpl,
     struct llama_context* ctx,
-    llama_token id_last,
-    int32_t n_past,
-    int32_t last_tok_idx);
+    const llama_tokens & ids_accepted,
+    int32_t n_past_base,
+    llama_token id_base
+);
 
 // sample up to n_draft tokens and add them to the batch using the draft model
 llama_tokens common_speculative_gen_draft(
@@ -49,11 +50,4 @@ llama_tokens common_speculative_gen_draft(
         const llama_tokens & prompt,
               llama_token    id_last);
 
-void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
-
-void mtp_accept_tokens(
-    struct llama_context * ctx,
-    const std::vector<llama_token> & ids,
-    int32_t n_past_base,
-    llama_seq_id seq_id
-);
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
\ No newline at end of file
diff --git a/include/llama.h b/include/llama.h
index 0b15d4bf1cd0d..02737444db26d 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -225,7 +225,7 @@ extern "C" {
         MTP_OP_NONE,
         MTP_OP_WARMUP,
         MTP_OP_UPDATE_ACCEPTED,
-        MTP_OP_DRAFT_GEN,
+        MTP_OP_UPDATE_AND_DRAFT
     } llama_mtp_op_type;
 
     typedef struct llama_mtp_params {
@@ -1471,15 +1471,15 @@ extern "C" {
     //
     LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
-
+
     /**
      * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
      * This is used after speculative validation when only a subset of draft tokens are accepted.
      * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
      * @return true on success.
*/ - LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted); - + LLAMA_API bool llama_mtp_prepare_sinfo_for_update_and_draft(struct llama_context * ctx, size_t n_update); + /** * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode. * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned. diff --git a/src/llama-context.cpp b/src/llama-context.cpp index fb35d6c79debf..193aa734acb2e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -17,10 +17,25 @@ // // llama_context // +// Key for the graph cache. It contains all parameters that define the graph topology. +struct llama_graph_cache_key { + uint32_t n_tokens; + uint32_t n_outputs; + llama_mtp_op_type op_type; + bool causal_attn; + + bool operator<(const llama_graph_cache_key& other) const { + return std::tie(n_tokens, n_outputs, op_type, causal_attn) < + std::tie(other.n_tokens, other.n_outputs, other.op_type, other.causal_attn); + } +}; + struct llama_context_kv_cache_data { llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos; llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force; const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr; + std::map graph_cache; + llm_graph_result_ptr gf_res_prev_validation; }; llama_context::llama_context( @@ -745,43 +760,113 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll return nullptr; } - auto * res = gf_res_prev.get(); - auto * gf = res->get_gf(); - - // the new graph parameters - // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); - - if (!graph_reuse_disable && res->can_reuse(gparams)) { + auto * kvd = static_cast(kv_cache_data); + llm_graph_result * res; + + if (mtp_params.op_type != MTP_OP_NONE) { + int32_t n_outputs = 0; + for (int i = 0; i < ubatch.n_tokens; ++i) { if (ubatch.output[i]) n_outputs++; } + const llama_graph_cache_key key = { ubatch.n_tokens, (uint32_t)n_outputs, mtp_params.op_type, cparams.causal_attn }; + + auto & res_ptr = kvd->graph_cache[key]; + if (!res_ptr) { + LLAMA_LOG_INFO("[GRAPH-CACHE] Creating a new graph container for key (op=%d, tok=%d, out=%d)\n", + (int)key.op_type, key.n_tokens, key.n_outputs); + res_ptr = std::make_unique(graph_max_nodes()); + } + res = res_ptr.get(); + + // the new graph parameters + // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters + const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); + + // if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); - - n_reused++; - } else { - res->reset(); - + // LLAMA_LOG_INFO("[GRAPH-CACHE] HIT, reusing graph STRUCTURE for key (op=%d, tok=%d, out=%d)\n", + // (int)key.op_type, key.n_tokens, key.n_outputs); + // n_reused++; + // } else { + LLAMA_LOG_INFO("[GRAPH-CACHE] MISS, RECONSTRUCTING THE STRUCTURE of the graph for key (op=%d, tok=%d, out=%d)\n", + (int)key.op_type, key.n_tokens, key.n_outputs); + + const int64_t t_reset_start_us = ggml_time_us(); ggml_backend_sched_reset(sched.get()); + const int64_t t_reset_end_us = ggml_time_us(); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - //const auto t_start_us = ggml_time_us(); - - gf = model.build_graph(gparams); - - 
//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + res->reset(); + res->set_params(gparams); + const int64_t t_build_start_us = ggml_time_us(); + res->gf = model.build_graph(gparams); + const int64_t t_build_end_us = ggml_time_us(); + LLAMA_LOG_INFO("[PERF-GRAPH] Graph build (op=%d): %.2f ms\n", (int)mtp_params.op_type, (t_build_end_us - t_build_start_us) / 1000.0); - if (!gf) { + if (!res->gf) { LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); ret = GGML_STATUS_FAILED; return nullptr; } - if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + const int64_t t_alloc_start_us = ggml_time_us(); + if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) { LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); ret = GGML_STATUS_ALLOC_FAILED; return nullptr; } + const int64_t t_alloc_end_us = ggml_time_us(); + LLAMA_LOG_INFO("[PERF-GRAPH] sched_reset: %.2f ms | sched_alloc: %.2f ms (op=%d)\n", + (t_reset_end_us - t_reset_start_us) / 1000.0, + (t_alloc_end_us - t_alloc_start_us) / 1000.0, + (int)mtp_params.op_type); + // } + + } else { + res = gf_res_prev.get(); + const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); + + if (!graph_reuse_disable && res->can_reuse(gparams)) { + LLAMA_LOG_INFO("%s: reusing previous graph\n", __func__); + n_reused++; + } else { + LLAMA_LOG_INFO("%s: RECONSTRUCTED graph...\n", __func__); + + const int64_t t_reset_start_us = ggml_time_us(); + ggml_backend_sched_reset(sched.get()); + const int64_t t_reset_end_us = ggml_time_us(); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + res->reset(); + res->set_params(gparams); + //const auto t_start_us = ggml_time_us(); + + const int64_t t_build_start_us = ggml_time_us(); + res->gf = model.build_graph(gparams); + const int64_t t_build_end_us = ggml_time_us(); + LLAMA_LOG_INFO("[PERF-GRAPH] Graph build (op=%d): %.2f ms\n", (int)mtp_params.op_type, (t_build_end_us - t_build_start_us) / 1000.0); + + //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + + if (!res->gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + const int64_t t_alloc_start_us = ggml_time_us(); + if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } + const int64_t t_alloc_end_us = ggml_time_us(); + LLAMA_LOG_INFO("[PERF-GRAPH] sched_reset: %.2f ms | sched_alloc: %.2f ms (op=%d)\n", + (t_reset_end_us - t_reset_start_us) / 1000.0, + (t_alloc_end_us - t_alloc_start_us) / 1000.0, + (int)mtp_params.op_type); + } } - if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation + if (mtp_params.op_type != MTP_OP_NONE) { if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) { ret = GGML_STATUS_FAILED; return nullptr; @@ -1156,13 +1241,12 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract logits if (t_logits && n_outputs > 0) { // MTP operations that are purely for updating the KV cache - // (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor + // (MTP_OP_WARMUP) also produce a logit tensor // as a side effect of running the graph. If these logits are copied // back to the main context buffer, they will overwrite the valid logits // produced by the main model's pass, leading to incorrect sampling. 
// This condition explicitly prevents that copy for cache-only operations. - if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP && - batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) { + if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); @@ -1174,6 +1258,8 @@ int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); } + } else { + LLAMA_LOG_DEBUG("%s: Skipping logit copy for MTP_OP_WARMUP.\n", __func__); } } @@ -3016,24 +3102,55 @@ bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) { } -bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) { - auto * kvd = static_cast(ctx->kv_cache_data); +bool llama_context::prepare_sinfo_for_update_and_draft(size_t n_update) { + auto * kvd = static_cast(this->kv_cache_data); const auto & last_sinfo = kvd->last_main_model_sinfos; - if (last_sinfo.empty() || last_sinfo[0].idxs.empty()) { - LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.", __func__); + if (last_sinfo.empty() || (n_update > 0 && last_sinfo[0].idxs.empty())) { + LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.\n", __func__); return false; } kvd->resized_sinfo_for_force = last_sinfo; - - kvd->resized_sinfo_for_force[0].idxs[0].resize(n_accepted); + if (kvd->resized_sinfo_for_force[0].idxs.empty()) { + kvd->resized_sinfo_for_force[0].idxs.push_back({}); + } + kvd->resized_sinfo_for_force[0].idxs[0].resize(n_update); + + if (this->memory) { + llama_ubatch ubatch_draft_fake = {}; + llama_seq_id seq_id_draft = 0; + ubatch_draft_fake.n_tokens = 1; + ubatch_draft_fake.n_seqs_unq = 1; + ubatch_draft_fake.seq_id_unq = &seq_id_draft; + int32_t n_seq_id_fake = 1; + llama_seq_id* seq_id_ptr_fake = &seq_id_draft; + ubatch_draft_fake.n_seq_id = &n_seq_id_fake; + ubatch_draft_fake.seq_id = &seq_id_ptr_fake; + + auto * memory_unified = static_cast(this->memory.get()); + auto sinfo_draft = memory_unified->find_slot(ubatch_draft_fake, true); + + if (sinfo_draft.empty()) { + LLAMA_LOG_ERROR("%s: Failed to find a cache slot for the draft token.\n", __func__); + return false; + } + + kvd->resized_sinfo_for_force[0].idxs[0].push_back(sinfo_draft.idxs[0][0]); + } else { + LLAMA_LOG_ERROR("%s: The context memory is not initialized.\n", __func__); + return false; + } kvd->forced_sinfos = &kvd->resized_sinfo_for_force; return true; } +bool llama_mtp_prepare_sinfo_for_update_and_draft(struct llama_context * ctx, size_t n_update) { + return ctx->prepare_sinfo_for_update_and_draft(n_update); +} + void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) { auto * kvd = static_cast(ctx->kv_cache_data); kvd->forced_sinfos = nullptr; @@ -3065,8 +3182,17 @@ std::unique_ptr llama_context::initialize_decode_context mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) { LLAMA_LOG_DEBUG("%s: Forcing sinfos, bypassing find_slot.\n", __func__); + + int n_inplace = 0; + if (batch_inp.mtp_params.op_type == MTP_OP_UPDATE_AND_DRAFT) { + n_inplace = batch_inp.n_tokens - 1; + } else if (batch_inp.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED || + batch_inp.mtp_params.op_type == MTP_OP_WARMUP) { + n_inplace = 
batch_inp.n_tokens; + } + mctx = static_cast(memory.get())->init_batch_with_sinfos( - *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true + *balloc, cparams.n_ubatch, *kvd->forced_sinfos, n_inplace ); } else { mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); @@ -3093,18 +3219,17 @@ bool llama_context::prepare_mtp_graph_inputs( ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); const float * source_hidden_state = nullptr; - if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + if (mtp_params.op_type == MTP_OP_WARMUP || + mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED || + mtp_params.op_type == MTP_OP_UPDATE_AND_DRAFT) { source_hidden_state = this->embd; - } else { // MTP_OP_DRAFT_GEN - source_hidden_state = this->draft_input_hidden_state; + // TODO: Simplify the logic } if (source_hidden_state != nullptr && hidden_states_input != nullptr) { const char * op_type; if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { op_type = "MTP_UPDATE"; - } else { // MTP_OP_DRAFT_GEN - op_type = "DRAFT_GEN"; } ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); diff --git a/src/llama-context.h b/src/llama-context.h index 4d77d5d81aef1..153eccb9bdbb8 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -225,6 +225,8 @@ struct llama_context { // For MTP KV cache cell reuse void * kv_cache_data; + bool prepare_sinfo_for_update_and_draft(size_t n_update); + private: llm_graph_params graph_params( llm_graph_result * res, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index be7de40454e80..50bd8e03f7408 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -442,34 +442,22 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { bool llm_graph_result::can_reuse(const llm_graph_params & params) { if (!this->params.allow_reuse(params)) { - if (debug > 1) { - LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__); - } - + LLAMA_LOG_WARN("[GRAPH-REUSE-FAIL] Failure in 'allow_reuse'. 
Incompatible parameters."); + LLAMA_LOG_WARN(" n_tokens: %d vs %d, op_type: %d vs %d", + this->params.ubatch.n_tokens, params.ubatch.n_tokens, + (int)this->params.mtp_params.op_type, (int)params.mtp_params.op_type); return false; } - if (debug > 1) { - LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size()); - } - - bool res = true; - - for (auto & input : inputs) { - const bool cur = input->can_reuse(params); - - if (debug > 1) { - LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur); + for (size_t i = 0; i < inputs.size(); ++i) { + if (!inputs[i]->can_reuse(params)) { + LLAMA_LOG_WARN("[GRAPH-REUSE-FAIL] Failure in 'can_reuse' of the input node #%zu.", i); + return false; } - - res = res && cur; } - if (debug > 0) { - LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res); - } - - return res; + LLAMA_LOG_DEBUG("%s: can reuse graph = true\n", __func__); + return true; } llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) { diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 8d9b1f631f7b6..34efb4425a3cb 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -512,7 +512,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch_with_sinfos( llama_batch_allocr & balloc, uint32_t n_ubatch, const slot_info_vec_t & sinfos, - bool is_inplace_update) { + int n_inplace) { if (sinfos.empty()) { return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); @@ -533,7 +533,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch_with_sinfos( } return std::make_unique( - this, sinfos, std::move(ubatches), is_inplace_update); + this, sinfos, std::move(ubatches), n_inplace); } llama_memory_context_ptr llama_kv_cache_unified::init_full() { @@ -976,12 +976,12 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ return res; } -void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) { +void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, int n_inplace) { // For "in-place" updates (MTP warmup/accept), we only update the tensor data. // The cell metadata (logical position, sequence ID) has already been set // by the main model's pass. We must skip all metadata modifications // to prevent `pos_set` from asserting on an already-set cell. - if (!is_inplace_update) { + if (n_inplace < ubatch.n_tokens) { // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -992,7 +992,8 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { - for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + for (uint32_t ii = n_inplace; ii < sinfo.size(); ++ii) { + // TODO: Check if n_inplace can work without MTP. 
const uint32_t i = s*sinfo.size() + ii; auto & cells = v_cells[sinfo.strm[s]]; @@ -2336,7 +2337,7 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_kv_cache_unified::slot_info_vec_t sinfos, std::vector ubatches, - bool is_inplace_update) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)), is_inplace_update(is_inplace_update) { + int n_inplace) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)), n_inplace(n_inplace) { } llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; @@ -2361,7 +2362,7 @@ bool llama_kv_cache_unified_context::apply() { return true; } - kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur], is_inplace_update); + kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur], n_inplace); n_kv = kv->get_n_kv(); diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index f64f7faa5c062..5e874788de575 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -121,7 +121,7 @@ class llama_kv_cache_unified : public llama_memory_i { llama_batch_allocr & balloc, uint32_t n_ubatch, const slot_info_vec_t & sinfos, - bool is_inplace_update); + int n_inplace); llama_memory_context_ptr init_full() override; @@ -187,7 +187,7 @@ class llama_kv_cache_unified : public llama_memory_i { slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] - void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update = false); + void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, int n_inplace = 0); // // input API @@ -328,7 +328,7 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { llama_kv_cache_unified * kv, slot_info_vec_t sinfos, std::vector ubatches, - bool is_inplace_update = false); + int n_inplace = 0); virtual ~llama_kv_cache_unified_context(); @@ -409,5 +409,5 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { // as the cache gets filled, the benefit from this heuristic disappears int32_t n_kv; - bool is_inplace_update = false; + int n_inplace = 0; }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ab7daee356ae9..b21d661c6fb24 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13796,24 +13796,15 @@ struct llm_build_glm4_moe : public llm_graph_context { if (params.mtp_params.op_type != MTP_OP_NONE) { ggml_tensor* hidden_states_from_main_model; - if (params.mtp_params.op_type == MTP_OP_WARMUP || params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { - hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); - ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); - ggml_set_input(hidden_states_from_main_model); - - auto inp_mtp = std::make_unique(); - inp_mtp->states = hidden_states_from_main_model; - res->add_input(std::move(inp_mtp)); - } else { - hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); - ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); - ggml_set_input(hidden_states_from_main_model); + hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); - auto inp_mtp = std::make_unique(); - inp_mtp->states = hidden_states_from_main_model; - res->add_input(std::move(inp_mtp)); - } + ggml_set_name(hidden_states_from_main_model, 
"result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); + const int il_mtp = hparams.n_layer - 1; const auto & mtp_layer = model.layers[il_mtp]; res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head); @@ -13971,8 +13962,9 @@ struct llm_build_glm4_moe : public llm_graph_context { ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings); const int il = hparams.n_layer - 1; + // cb(embd_copy, "mtp_embd_copy", il); ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy); - + // cb(sum_node, "mtp_sum_node", il); ggml_set_name(sum_node, "mtp_input_sum"); ggml_tensor * inp_pos = build_inp_pos(); @@ -13983,30 +13975,48 @@ struct llm_build_glm4_moe : public llm_graph_context { ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); + // cb(combined, "mtp_combined", il); + ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // now proceed through last layer (skipped in main model) ggml_tensor * inpSA = cur; // Pre-attention norm for the MTP block cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il); + // cb(cur, "mtp_attn_norm", il); // self-attention { ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur); + // if (mtp_layer.bq) { + // Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq); + // cb(Qcur, "mtp_q_bias", il); // ADICIONADO + // } if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq); cb(Qcur, "Qcur", il); ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur); + // if (mtp_layer.bk) { + // Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk); + // cb(Kcur, "mtp_k_bias", il); // ADICIONADO + // } if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur); + // if (mtp_layer.bv) { + // Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv); + // cb(Vcur, "mtp_v_bias", il); // ADICIONADO + // } if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv); cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + // cb(Qcur, "mtp_q_reshaped", il); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + // cb(Kcur, "mtp_k_reshaped", il); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + // cb(Vcur, "mtp_v_reshaped", il); // Apply Q/K norm if available (GLM-4.5 355B variant) if (mtp_layer.attn_q_norm) { @@ -14023,12 +14033,14 @@ struct llm_build_glm4_moe : public llm_graph_context { n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + // cb(Qcur, "mtp_q_rope", il); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + // cb(Kcur, "mtp_k_rope", il); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -14040,8 +14052,10 @@ struct llm_build_glm4_moe : public llm_graph_context { } ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + // cb(ffn_inp, "mtp_ffn_inp", il); cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il); + // cb(cur, "post_attn_norm", il); // moe ffn for nextn block { @@ -14073,7 +14087,10 @@ struct llm_build_glm4_moe : public llm_graph_context { cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + // cb(cur, "mtp_ffn_residual", il); + cur = build_norm(cur, 
mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); + // cb(cur, "mtp_final_norm", il); cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); return cur; @@ -18305,7 +18322,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { - + const int64_t t_start_us = ggml_time_us(); std::unique_ptr llm; switch (arch) { case LLM_ARCH_LLAMA: @@ -18668,6 +18685,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); } + const int64_t t_end_us = ggml_time_us(); + LLAMA_LOG_INFO( + "[PERF] Graph build time: %.2f ms (MTP path: %s)\n", + (t_end_us - t_start_us) / 1000.0, + params.mtp_params.op_type != MTP_OP_NONE ? "yes" : "no" + ); return llm->res->get_gf(); } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a24532c6939af..3db694653e08e 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1366,6 +1366,7 @@ struct server_slot { // Speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted + llama_tokens ids_prev_accepted; void reset() { SLT_DBG(*this, "%s", "\n"); @@ -3468,7 +3469,10 @@ struct server_context { batch.logits + i, }; + const int64_t t_prompt_main_start_us = ggml_time_us(); const int ret = llama_decode(ctx, batch_view); + const int64_t t_prompt_main_end_us = ggml_time_us(); + LOG_INF("[PERF-PROMPT] Main model prompt processing: %.2f ms\n", (t_prompt_main_end_us - t_prompt_main_start_us) / 1000.0); metrics.on_decoded(slots); @@ -3516,7 +3520,10 @@ struct server_context { // from the main model's prompt processing pass. This ensures the MTP layer's // KV cache is perfectly aligned. if (llama_mtp_prepare_sinfo_for_warmup(ctx)) { + const int64_t t_warmup_start_us = ggml_time_us(); mtp_update_kv_cache(ctx, batch_view, true); + const int64_t t_warmup_end_us = ggml_time_us(); + LOG_INF("[PERF-PROMPT] MTP warm-up: %.2f ms\n", (t_warmup_end_us - t_warmup_start_us) / 1000.0); // Clean up the forced state to not affect subsequent decodes. llama_mtp_cancel_sinfo_update(ctx); } else { @@ -3558,10 +3565,7 @@ struct server_context { } const int tok_idx = slot.i_batch - i; - // Sets the initial state for the first draft generation. - if (slot.has_mtp) { - llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1)); - } + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.last_tok_idx = tok_idx; @@ -3636,10 +3640,17 @@ struct server_context { llama_token id = slot.sampled; llama_tokens draft; - if (slot.has_mtp) { - llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); - draft.reserve(1); - draft.push_back(draft_id); + // const int64_t t_spec_start_us = ggml_time_us(); + if (slot.has_mtp) { + const int32_t n_past_before_speculation = slot.n_past; + LOG_INF("[MTP-FLOW] Starting Update+Draft. 
n_past_base=%d, id_base=%d\n", n_past_before_speculation, id); + + llama_token draft_id = mtp_update_and_draft(slot.smpl, ctx, slot.ids_prev_accepted, n_past_before_speculation, id); + + if (draft_id >= 0) { + draft.push_back(draft_id); + } + slot.ids_prev_accepted.clear(); } else { struct common_speculative_params params_spec; @@ -3672,22 +3683,16 @@ struct server_context { } SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + const int64_t t_valid_start_us = ggml_time_us(); llama_decode(ctx, slot.batch_spec); + const int64_t t_valid_end_us = ggml_time_us(); // the accepted tokens from the speculation + const int64_t t_accept_start_us = ggml_time_us(); const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - - if (slot.has_mtp) { - llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1)); + const int64_t t_accept_end_us = ggml_time_us(); - if (!ids.empty()) { - llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1)); - } else { - llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0)); - } - - mtp_accept_tokens(ctx, ids, slot.n_past, slot.id); - } + slot.ids_prev_accepted = ids; slot.n_past += ids.size(); slot.n_decoded += ids.size();
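For reference, a minimal standalone sketch (not part of the patch) of the batch layout that the new mtp_update_and_draft() builds: the previously accepted tokens are replayed at the positions they were validated at, the last sampled token is appended at n_past_base, and only the final logits row (index n_total - 1) is sampled for the draft. Token ids and positions below are made-up example values.

// Illustrative sketch, not part of the patch: how mtp_update_and_draft() lays out its batch.
// Token ids and positions are made-up example values.
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> ids_accepted = {1001, 1002}; // tokens accepted in the last round (hypothetical ids)
    const int id_base     = 1003;                       // last sampled token, the anchor for the new draft
    const int n_past_base = 42;                         // n_past before this speculation round

    const int n_update            = (int) ids_accepted.size();
    const int n_total             = n_update + 1;
    const int n_past_update_start = n_past_base - n_update;

    // accepted tokens refresh the MTP KV cache at the positions they were validated at
    for (int i = 0; i < n_update; ++i) {
        std::printf("batch[%d]: token %d at pos %d (KV update, in place)\n",
                    i, ids_accepted[i], n_past_update_start + i);
    }

    // the base token goes last; its logits row (n_total - 1) is the one sampled for the draft,
    // and its temporary KV cell at pos n_past_base is purged again after the decode
    std::printf("batch[%d]: token %d at pos %d (draft generation)\n",
                n_total - 1, id_base, n_past_base);
    return 0;
}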
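A toy sketch of the "hybrid" slot list assembled by llama_mtp_prepare_sinfo_for_update_and_draft(): the accepted prefix of the last main-model sinfo is reused so the MTP pass overwrites those cells in place, and one freshly found cell is appended for the draft token. Plain integer cell indices and find_draft_slot() are hypothetical stand-ins for the real slot_info / find_slot machinery.

// Illustrative sketch, not part of the patch: the "hybrid" slot list for MTP_OP_UPDATE_AND_DRAFT.
// Cell indices are plain ints; find_draft_slot() is a hypothetical stand-in for find_slot().
#include <cstdio>
#include <vector>

static int find_draft_slot() {
    return 7; // pretend the cache handed out a free cell for the draft token
}

int main() {
    // cells the main model used for its last validation decode (one per batch token)
    const std::vector<int> last_main_slots = {3, 4, 5, 6};

    const int n_update = 2; // only two of the validated tokens were accepted this round

    // reuse exactly the accepted prefix of the validation slots ...
    std::vector<int> forced_slots(last_main_slots.begin(), last_main_slots.begin() + n_update);

    // ... and append one freshly allocated cell for the new draft position
    forced_slots.push_back(find_draft_slot());

    for (int idx : forced_slots) {
        std::printf("forced slot idx: %d\n", idx); // -> 3 4 7
    }
    return 0;
}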
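The n_inplace parameter that replaces the old is_inplace_update flag splits a ubatch into an in-place prefix and a fresh suffix: for MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED every token is in place, while MTP_OP_UPDATE_AND_DRAFT passes n_tokens - 1 so only the final draft position claims new cell metadata. A self-contained sketch of that split, with toy cell/metadata types standing in for the unified KV cache:

// Illustrative sketch, not part of the patch: what the n_inplace split in apply_ubatch() means.
// cell_meta is a toy stand-in for the unified KV cache's per-cell bookkeeping.
#include <cstdint>
#include <cstdio>
#include <vector>

struct cell_meta {
    int32_t pos     = -1;
    bool    has_seq = false;
};

// Tokens [0, n_inplace) already had their cell metadata written by the main model's
// validation pass, so only their tensor data would be overwritten (not modeled here).
// Tokens [n_inplace, n_tokens) sit in fresh cells and need pos/seq metadata assigned.
static void apply_metadata(std::vector<cell_meta> & cells,
                           const std::vector<uint32_t> & slot_idxs,
                           const std::vector<int32_t> & token_pos,
                           int n_inplace) {
    const int n_tokens = (int) token_pos.size();
    for (int ii = n_inplace; ii < n_tokens; ++ii) {
        cell_meta & c = cells[slot_idxs[ii]];
        c.pos     = token_pos[ii];
        c.has_seq = true;
    }
}

int main() {
    std::vector<cell_meta> cells(8);

    // MTP_OP_UPDATE_AND_DRAFT with 2 accepted tokens + 1 base token: n_inplace = n_tokens - 1,
    // so only the final (draft) position claims new metadata; cells 3 and 4 stay exactly as
    // the main model left them.
    apply_metadata(cells, /*slot_idxs=*/{3, 4, 5}, /*token_pos=*/{10, 11, 12}, /*n_inplace=*/2);

    for (size_t i = 0; i < cells.size(); ++i) {
        if (cells[i].has_seq) {
            std::printf("cell %zu -> pos %d\n", i, cells[i].pos);
        }
    }
    return 0;
}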
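The per-topology graph cache keyed by (n_tokens, n_outputs, op_type, causal_attn) can be pictured with the standalone sketch below. Note that in the patch as posted the can_reuse() fast path is commented out, so a cache hit currently reuses the llm_graph_result container but still rebuilds and re-allocates the graph. The types here (mtp_op, graph_stub) are stand-ins for llama_mtp_op_type and llm_graph_result, not the real API.

// Illustrative sketch, not part of the patch: the map-based graph cache in llama_context_kv_cache_data.
#include <cstdint>
#include <cstdio>
#include <map>
#include <memory>
#include <tuple>

enum class mtp_op : int { none, warmup, update_accepted, update_and_draft };

struct graph_cache_key {
    uint32_t n_tokens;
    uint32_t n_outputs;
    mtp_op   op_type;
    bool     causal_attn;

    // same ordering trick as llama_graph_cache_key: lexicographic compare via std::tie
    bool operator<(const graph_cache_key & other) const {
        return std::tie(n_tokens, n_outputs, op_type, causal_attn) <
               std::tie(other.n_tokens, other.n_outputs, other.op_type, other.causal_attn);
    }
};

struct graph_stub { int built_for_tokens; }; // stand-in for llm_graph_result

int main() {
    std::map<graph_cache_key, std::unique_ptr<graph_stub>> cache;

    auto get_or_build = [&](const graph_cache_key & key) -> graph_stub * {
        auto & slot = cache[key];        // default-constructs an empty entry on miss
        if (!slot) {
            std::printf("MISS: building graph container (op=%d, tok=%u, out=%u)\n",
                        (int) key.op_type, key.n_tokens, key.n_outputs);
            slot = std::make_unique<graph_stub>(graph_stub{(int) key.n_tokens});
        } else {
            std::printf("HIT: reusing graph container\n");
        }
        return slot.get();
    };

    const graph_cache_key draft_key = {2, 1, mtp_op::update_and_draft, true};
    get_or_build(draft_key); // first MTP update+draft batch of this shape -> MISS
    get_or_build(draft_key); // same shape on the next iteration -> HIT
    return 0;
}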
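Finally, a stub-only sketch of the per-iteration ordering that the server.cpp changes aim for: the tokens accepted in the previous round are carried in ids_prev_accepted, folded into the next mtp_update_and_draft() call (one decode for KV update plus draft), and then replaced by the tokens accepted in the current validation pass. Every function below is a placeholder, not the real server or llama API.

// Illustrative sketch, not part of the patch: the intended server-side speculation loop.
#include <cstdio>
#include <vector>

using token = int;

static token mtp_update_and_draft_stub(const std::vector<token> & ids_prev_accepted,
                                       int n_past_base, token id_base) {
    // one decode: refresh MTP KV cells for ids_prev_accepted, then draft from id_base
    (void) ids_prev_accepted; (void) n_past_base;
    return id_base + 1; // pretend the MTP head drafted "the next" token
}

static std::vector<token> validate_stub(token id_base, token draft_id) {
    // main model validates {id_base's continuation, draft}; pretend both were accepted
    (void) id_base;
    return {draft_id, draft_id + 1};
}

int main() {
    int   n_past     = 100;
    token id_sampled = 42;
    std::vector<token> ids_prev_accepted; // mirrors server_slot::ids_prev_accepted

    for (int iter = 0; iter < 3; ++iter) {
        // 1) single MTP decode: KV update for last round's accepted tokens + new draft
        const token draft_id = mtp_update_and_draft_stub(ids_prev_accepted, n_past, id_sampled);
        ids_prev_accepted.clear();

        // 2) main model validates the draft, accepting some prefix of it
        const std::vector<token> ids = validate_stub(id_sampled, draft_id);

        // 3) remember the accepted tokens so the next MTP call can update its KV cache
        ids_prev_accepted = ids;
        n_past    += (int) ids.size();
        id_sampled = ids.back();

        std::printf("iter %d: accepted %zu tokens, n_past=%d\n", iter, ids.size(), n_past);
    }
    return 0;
}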