@@ -373,48 +373,85 @@ llama_token mtp_speculative_gen_draft(
     if (!smpl) {
         return -1;
     }
+    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
+    const llama_pos draft_pos = n_past;
+    const llama_seq_id draft_seq_id = 0;
+    common_batch_add(mtp_batch, id_last, n_past, {0}, true);
 
-    llama_batch batch = llama_batch_init(1, 0, 1);
-    common_batch_add(batch, id_last, n_past, {0}, true);
+    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
 
-    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+    // Perform the MTP draft generation decode. This writes the MTP layer's
+    // KV state for the draft token into the cache.
+    llama_decode(ctx, mtp_batch);
+    llama_batch_free(mtp_batch);
+
+    // CRITICAL: Purge the metadata for the draft token we just wrote.
+    // This makes the physical cell available again for the main model's validation pass,
+    // preventing a cache state corruption where two cells map to the same logical position.
+    llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);
 
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
     const int n_vocab = llama_n_vocab(vocab);
-
     llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
-
     cur_p->size = n_vocab;
     for (int i = 0; i < n_vocab; ++i) {
         cur_p->data[i].id = i;
-        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0.
     }
     cur_p->sorted = false;
-
     common_sampler_apply_chain(smpl, cur_p);
+
+    return cur_p->data[0].id;
+}
 
-    const llama_token id = cur_p->data[0].id;
 
-    llama_batch_free(batch);
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch & batch, bool is_prompt_warmup) {
+    if (batch.n_tokens == 0) {
+        return;
+    }
 
-    return id;
-}
+    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
 
+    llama_batch mtp_batch = batch;
+    if (is_prompt_warmup) {
+        mtp_batch.mtp_params.op_type = MTP_OP_WARMUP;
+    } else {
+        mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
+    }
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data> & tokens, size_t batch_start, size_t n_tokens) {
-    mtp_kv_update_data token;
+    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
+        mtp_batch.logits[i] = true;
+    }
+    llama_decode(ctx, mtp_batch);
+}
 
-    if (n_tokens < 0) {
-        n_tokens = tokens.size();
+void mtp_accept_tokens(
+    struct llama_context * ctx,
+    const std::vector<llama_token> & ids,
+    int32_t n_past_base,
+    llama_seq_id seq_id
+) {
+    if (ids.empty()) {
+        return;
     }
 
-    for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
-        token = tokens[i];
-        // fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
+    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
+    // This sets up the context for a "forced sinfo" decode.
+    if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
+        return;
+    }
 
-        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
+    // Build a new batch containing only the accepted tokens.
+    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
+    for (size_t i = 0; i < ids.size(); ++i) {
+        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
     }
 
-    tokens.clear();
-}
+    mtp_update_kv_cache(ctx, accepted_batch, false);
+
+    // Clean up the forced state to not affect subsequent, normal decode calls.
+    llama_mtp_cancel_sinfo_update(ctx);
+
+    llama_batch_free(accepted_batch);
+}
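Taken together, this hunk splits the MTP work into three entry points: mtp_speculative_gen_draft produces a single draft token (and immediately purges its KV metadata), mtp_update_kv_cache runs an MTP-only decode over a batch, and mtp_accept_tokens commits only the accepted tokens through the forced-sinfo path. A rough caller-side sketch of how these might be chained per generation step follows; the validation helper and the exact draft-function signature are assumptions inferred from the removed call site, not something this commit defines.

    // Hypothetical caller-side sketch (not part of this commit).
    // Assumes ctx, smpl, id_last and n_past are maintained by the serving loop,
    // and that validate_with_main_model() stands in for the main model's
    // verification decode over the drafted token(s).
    llama_token draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, /*last_tok_idx=*/0);

    // Main model validates the draft (placeholder helper), returning the tokens it accepted.
    std::vector<llama_token> accepted = validate_with_main_model(ctx, id_last, draft);

    // Re-run only the accepted tokens through the MTP layer so its KV cache
    // stays in sync with the main model's cache, then advance the position.
    mtp_accept_tokens(ctx, accepted, /*n_past_base=*/n_past, /*seq_id=*/0);
    n_past += (int32_t) accepted.size();
    id_last = accepted.back();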