@@ -1155,16 +1155,25 @@ int llama_context::decode(const llama_batch & batch_inp) {
11551155
11561156 // extract logits
11571157 if (t_logits && n_outputs > 0 ) {
1158- ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend (sched.get (), t_logits);
1159- GGML_ASSERT (backend_res != nullptr );
1160- GGML_ASSERT (logits != nullptr );
1161-
1162- float * logits_out = logits + n_outputs_prev*n_vocab;
1163-
1164- if (n_outputs) {
1165- GGML_ASSERT ( n_outputs_prev + n_outputs <= n_outputs_all);
1166- GGML_ASSERT ((n_outputs_prev + n_outputs)*n_vocab <= (int64_t ) logits_size);
1167- ggml_backend_tensor_get_async (backend_res, t_logits, logits_out, 0 , n_outputs*n_vocab*sizeof (float ));
1158+ // MTP operations that are purely for updating the KV cache
1159+ // (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor
1160+ // as a side effect of running the graph. If these logits are copied
1161+ // back to the main context buffer, they will overwrite the valid logits
1162+ // produced by the main model's pass, leading to incorrect sampling.
1163+ // This condition explicitly prevents that copy for cache-only operations.
1164+ if (batch_inp.mtp_params .op_type != MTP_OP_WARMUP &&
1165+ batch_inp.mtp_params .op_type != MTP_OP_UPDATE_ACCEPTED) {
1166+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend (sched.get (), t_logits);
1167+ GGML_ASSERT (backend_res != nullptr );
1168+ GGML_ASSERT (logits != nullptr );
1169+
1170+ float * logits_out = logits + n_outputs_prev*n_vocab;
1171+
1172+ if (n_outputs) {
1173+ GGML_ASSERT ( n_outputs_prev + n_outputs <= n_outputs_all);
1174+ GGML_ASSERT ((n_outputs_prev + n_outputs)*n_vocab <= (int64_t ) logits_size);
1175+ ggml_backend_tensor_get_async (backend_res, t_logits, logits_out, 0 , n_outputs*n_vocab*sizeof (float ));
1176+ }
11681177 }
11691178 }
11701179
0 commit comments