Commit a99709d

mtp-batch(refactor): Extract decode context and MTP input logic into helper methods
1 parent: 913af8f · commit: a99709d

2 files changed: +84 -43 lines changed

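The refactor is a straight extract-method change: decode() now delegates memory-context creation to initialize_decode_context() and receives ownership back, while process_ubatch() delegates MTP input staging to prepare_mtp_graph_inputs() and only checks a boolean result. The sketch below illustrates that calling shape with simplified stand-in types (DecodeContext, Batch, GraphResult, MtpParams are placeholders invented for this example, not the real llama.cpp types); it is a minimal illustration of the control flow after the commit, not the actual implementation.

#include <cstdio>
#include <memory>

// Stand-in types; the real code uses llama_memory_context_i, llama_batch,
// llm_graph_result and llama_mtp_params from the branch.
struct DecodeContext {};
struct Batch         {};
struct GraphResult   { bool has_hidden_state_input = true; };
struct MtpParams     { bool active = false; };

// Helper 1: owns the policy for building the memory context and hands
// ownership back to the caller (mirrors initialize_decode_context).
static std::unique_ptr<DecodeContext> initialize_decode_context(const Batch & batch, bool output_all) {
    (void) batch; (void) output_all;
    return std::make_unique<DecodeContext>();
}

// Helper 2: stages graph inputs and reports success instead of failing
// inline (mirrors prepare_mtp_graph_inputs).
static bool prepare_mtp_graph_inputs(GraphResult & res, const MtpParams & params) {
    if (!params.active) {
        return true; // nothing to do for non-MTP operations
    }
    return res.has_hidden_state_input; // fail cleanly if the input tensor is missing
}

int main() {
    Batch batch;
    auto mctx = initialize_decode_context(batch, /*output_all=*/true);
    if (!mctx) {
        return -2; // decode() returns -2 when no memory context could be built
    }

    GraphResult res;
    MtpParams params;
    params.active = true;
    if (!prepare_mtp_graph_inputs(res, params)) {
        std::fprintf(stderr, "failed to prepare MTP graph inputs\n");
        return 1;
    }

    std::puts("decode path set up");
    return 0;
}

Returning std::unique_ptr from the first helper keeps ownership with the caller exactly as before, and returning bool from the second lets the caller keep its existing GGML_STATUS_FAILED error path.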

src/llama-context.cpp

Lines changed: 76 additions & 43 deletions
@@ -794,28 +794,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     }

     if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
-        const char * target_tensor_name = "result_embd_pooled";
-        ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
-
-        const float * source_hidden_state = nullptr;
-        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
-            source_hidden_state = this->embd;
-        } else {
-            source_hidden_state = this->draft_input_hidden_state;
-        }
-
-        if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
-            const size_t n_embd = this->model.hparams.n_embd;
-            const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
-            double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
-            const char * op_type = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) ? "MTP_UPDATE" : "DRAFT_GEN";
-
-            LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
-
-            ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
-        } else {
-            LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
-                __func__, target_tensor_name);
+        if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) {
             ret = GGML_STATUS_FAILED;
             return nullptr;
         }
@@ -1089,27 +1068,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     std::unique_ptr<llama_memory_context_i> mctx;

     while (true) {
-        if (cparams.warmup) {
-            mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        } else {
-            if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
-                LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
-
-                mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
-                    *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
-                );
-            } else {
-                mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-
-                if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
-                    if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
-                        kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
-                    } else {
-                        kvd->last_main_model_sinfos.clear();
-                    }
-                }
-            }
-        }
+        mctx = this->initialize_decode_context(batch_inp, output_all);

         if (!mctx) {
             return -2;
@@ -3149,3 +3108,77 @@ void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     ctx->kv_cache_seq_rm(seq_id, p0, p1);
 }
+
+/*
+    Initializes the memory context for a decode operation.
+    The logic follows a specific priority:
+      1. Warmup: Always use a standard batch initialization.
+      2. Forced S-Info (MTP Updates): If a specific KV cache layout is forced, use it.
+      3. Default: Use a standard batch initialization, and if it's a main model pass,
+         save the resulting s-info for potential future reuse by MTP.
+*/
+std::unique_ptr<llama_memory_context_i> llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) {
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
+    std::unique_ptr<llama_memory_context_i> mctx;
+
+    if (cparams.warmup) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+    } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
+        LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
+        mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
+            *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
+        );
+    } else {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+
+        if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
+            if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
+                kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
+            } else {
+                kvd->last_main_model_sinfos.clear();
+            }
+        }
+    }
+
+    return mctx;
+}
+
+
+bool llama_context::prepare_mtp_graph_inputs(
+    llm_graph_result * res,
+    const llama_ubatch & ubatch,
+    const llama_mtp_params & mtp_params) {
+
+    const char * target_tensor_name = "result_embd_pooled";
+    ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
+
+    const float * source_hidden_state = nullptr;
+    if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
+        source_hidden_state = this->embd;
+    } else { // MTP_OP_DRAFT_GEN
+        source_hidden_state = this->draft_input_hidden_state;
+    }
+
+    if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
+        const size_t n_embd = this->model.hparams.n_embd;
+        const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
+        double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
+
+        const char * op_type;
+        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
+            op_type = "MTP_UPDATE";
+        } else { // MTP_OP_DRAFT_GEN
+            op_type = "DRAFT_GEN";
+        }
+
+        LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
+
+        ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
+    } else {
+        LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
+            __func__, target_tensor_name);
+        return false;
+    }
+
+    return true;
+}
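
For context, the [MTP-INPUT-CHECK] log line reduces the copied hidden state to a single checksum via calculate_vector_sum over n_tokens_for_sum * n_embd floats. That helper is not part of this diff; a minimal sketch of what such a checksum typically looks like, assumed here rather than taken from the branch, is a plain double-precision sum:

#include <cstddef>

// Hypothetical sketch: accumulate n floats in double precision so the
// checksum is stable enough to compare across runs. The real
// calculate_vector_sum in the branch may differ; it is not shown in this commit.
static double calculate_vector_sum(const float * data, size_t n) {
    double sum = 0.0;
    for (size_t i = 0; i < n; ++i) {
        sum += data[i];
    }
    return sum;
}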

src/llama-context.h

Lines changed: 8 additions & 0 deletions
@@ -231,6 +231,14 @@ struct llama_context {

     llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;

+    // Methods for MTP decode
+    std::unique_ptr<llama_memory_context_i> initialize_decode_context(const llama_batch & batch_inp, const bool output_all);
+
+    bool prepare_mtp_graph_inputs(
+        llm_graph_result * res,
+        const llama_ubatch & ubatch,
+        const llama_mtp_params & mtp_params);
+
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i & io);
