Skip to content

Commit cae85fe

Browse files
mtp-batch(fix): skip copying logits back for MTP KV-cache-only operations (MTP_OP_WARMUP, MTP_OP_UPDATE_ACCEPTED)
1 parent 0127c6b commit cae85fe

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

src/llama-context.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,16 +1155,25 @@ int llama_context::decode(const llama_batch & batch_inp) {
11551155

11561156
// extract logits
11571157
if (t_logits && n_outputs > 0) {
1158-
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1159-
GGML_ASSERT(backend_res != nullptr);
1160-
GGML_ASSERT(logits != nullptr);
1161-
1162-
float * logits_out = logits + n_outputs_prev*n_vocab;
1163-
1164-
if (n_outputs) {
1165-
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1166-
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
1167-
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
1158+
// MTP operations that are purely for updating the KV cache
1159+
// (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor
1160+
// as a side effect of running the graph. If these logits are copied
1161+
// back to the main context buffer, they will overwrite the valid logits
1162+
// produced by the main model's pass, leading to incorrect sampling.
1163+
// This condition explicitly prevents that copy for cache-only operations.
1164+
if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP &&
1165+
batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) {
1166+
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1167+
GGML_ASSERT(backend_res != nullptr);
1168+
GGML_ASSERT(logits != nullptr);
1169+
1170+
float * logits_out = logits + n_outputs_prev*n_vocab;
1171+
1172+
if (n_outputs) {
1173+
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1174+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
1175+
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
1176+
}
11681177
}
11691178
}
11701179

0 commit comments

Comments
 (0)