@@ -794,28 +794,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     }
 
     if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
-        const char * target_tensor_name = "result_embd_pooled";
-        ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
-
-        const float * source_hidden_state = nullptr;
-        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
-            source_hidden_state = this->embd;
-        } else {
-            source_hidden_state = this->draft_input_hidden_state;
-        }
-
-        if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
-            const size_t n_embd = this->model.hparams.n_embd;
-            const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
-            double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
-            const char * op_type = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) ? "MTP_UPDATE" : "DRAFT_GEN";
-
-            LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
-
-            ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
-        } else {
-            LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
-                __func__, target_tensor_name);
+        if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) {
             ret = GGML_STATUS_FAILED;
             return nullptr;
         }
@@ -1089,27 +1068,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     std::unique_ptr<llama_memory_context_i> mctx;
 
     while (true) {
-        if (cparams.warmup) {
-            mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        } else {
-            if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
-                LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
-
-                mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
-                    *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
-                );
-            } else {
-                mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-
-                if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
-                    if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
-                        kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
-                    } else {
-                        kvd->last_main_model_sinfos.clear();
-                    }
-                }
-            }
-        }
+        mctx = this->initialize_decode_context(batch_inp, output_all);
 
         if (!mctx) {
             return -2;
@@ -3149,3 +3108,77 @@ void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     ctx->kv_cache_seq_rm(seq_id, p0, p1);
 }
+
+/*
+    Initializes the memory context for a decode operation.
+    The logic follows a specific priority:
+    1. Warmup: Always use a standard batch initialization.
+    2. Forced S-Info (MTP Updates): If a specific KV cache layout is forced, use it.
+    3. Default: Use a standard batch initialization, and if it's a main model pass,
+       save the resulting s-info for potential future reuse by MTP.
+*/
+std::unique_ptr<llama_memory_context_i> llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) {
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
+    std::unique_ptr<llama_memory_context_i> mctx;
+
+    if (cparams.warmup) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+    } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
+        LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
+        mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
+            *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
+        );
+    } else {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+
+        if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
+            if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
+                kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
+            } else {
+                kvd->last_main_model_sinfos.clear();
+            }
+        }
+    }
+
+    return mctx;
+}
+
+
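+/*
+    Copies the MTP hidden-state input into the compute graph.
+    The source is the main model's embd buffer for warmup / accepted-token updates,
+    or the stored draft input hidden state for draft generation; it is written into
+    the graph's "result_embd_pooled" tensor after logging a checksum for debugging.
+    Returns false if the target tensor or the source buffer is missing.
+*/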
+bool llama_context::prepare_mtp_graph_inputs(
+        llm_graph_result * res,
+        const llama_ubatch & ubatch,
+        const llama_mtp_params & mtp_params) {
+
+    const char * target_tensor_name = "result_embd_pooled";
+    ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
+
+    const float * source_hidden_state = nullptr;
+    if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
+        source_hidden_state = this->embd;
+    } else { // MTP_OP_DRAFT_GEN
+        source_hidden_state = this->draft_input_hidden_state;
+    }
+
+    if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
+        const size_t n_embd = this->model.hparams.n_embd;
+        const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
+        double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
+
+        const char * op_type;
+        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
+            op_type = "MTP_UPDATE";
+        } else { // MTP_OP_DRAFT_GEN
+            op_type = "DRAFT_GEN";
+        }
+
+        LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
+
+        ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
+    } else {
+        LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
+            __func__, target_tensor_name);
+        return false;
+    }
+
+    return true;
+}
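
Note: the matching member declarations are not part of this excerpt. Assuming they are added to llama_context (presumably as private members in src/llama-context.h), a rough sketch of what they would look like:

    // hypothetical declarations inferred from the definitions above
    std::unique_ptr<llama_memory_context_i> initialize_decode_context(const llama_batch & batch_inp, const bool output_all);

    bool prepare_mtp_graph_inputs(
            llm_graph_result * res,
            const llama_ubatch & ubatch,
            const llama_mtp_params & mtp_params);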