@@ -370,56 +370,35 @@ llama_token mtp_speculative_gen_draft(
     int32_t n_past,
     int32_t last_tok_idx) {

-    llama_token token_data[] = { id_last };
-    llama_pos pos_data[] = { n_past };
-    int32_t n_seq_id_data[] = { 1 };
-    llama_seq_id seq_id_data_internal[] = { 0 };
-    llama_seq_id * seq_id_data[] = { seq_id_data_internal };
-    int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
-
-    llama_batch batch = {
-        /* .n_tokens = */ 1,
-        /* .token    = */ token_data,
-        /* .embd     = */ nullptr,
-        /* .pos      = */ pos_data,
-        /* .n_seq_id = */ n_seq_id_data,
-        /* .seq_id   = */ seq_id_data,
-        /* .logits   = */ logits_data
-    };
-
-    return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
-    // LOG_INF("updating kv cache for n_past: %d\n", n_past);
-
-    /*
     if (!smpl) {
         return -1;
     }
-    else {
-        common_sampler_sample(smpl, ctx, last_tok_idx, true);
-        const auto * cur_p = common_sampler_get_candidates(smpl);

-        //for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
-        //    LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-        //        k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-        //}
+    llama_batch batch = llama_batch_init(1, 0, 1);
+    common_batch_add(batch, id_last, n_past, { 0 }, true);

-        const llama_token id = cur_p->data[0].id;
-        return id;
+    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int n_vocab = llama_n_vocab(vocab);
+
+    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
+
+    cur_p->size = n_vocab;
+    for (int i = 0; i < n_vocab; ++i) {
+        cur_p->data[i].id = i;
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
     }
-    */
-    // LOG_INF("cur_p->size: %d\n", cur_p->size);
+    cur_p->sorted = false;

+    common_sampler_apply_chain(smpl, cur_p);

-    // add drafted token for each sequence
+    const llama_token id = cur_p->data[0].id;

-    // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
-    // smpl will accept the token if it doesn't get rejected by main model later
-    // common_sampler_accept(smpl, id, true);
+    llama_batch_free(batch);

-    // llama_tokens result;
-    // result.reserve(1);
-    // result.push_back(id);
-    // return result;
+    return id;
 }

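For context on the replacement above: the hand-rolled, stack-allocated llama_batch is swapped for the library's own helpers. llama_batch_init(1, 0, 1) allocates a heap-backed batch with room for one token and one seq id per token, and common_batch_add() appends the token. From memory of llama.cpp's common helpers, common_batch_add behaves roughly like the paraphrase below; this is illustrative, not a verbatim copy, so consult common/common.cpp for the authoritative version.

// Rough paraphrase of llama.cpp's common_batch_add(); illustrative only.
void common_batch_add(struct llama_batch & batch,
                      llama_token id,
                      llama_pos pos,
                      const std::vector<llama_seq_id> & seq_ids,
                      bool logits) {
    batch.token   [batch.n_tokens] = id;             // token id to decode
    batch.pos     [batch.n_tokens] = pos;            // its position in the sequence
    batch.n_seq_id[batch.n_tokens] = seq_ids.size(); // how many sequences it belongs to
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits  [batch.n_tokens] = logits;         // request logits for this token
    batch.n_tokens++;
}

The upshot for this patch: the batch now owns its arrays, which is why the new code must pair it with llama_batch_free(batch) before returning.
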
@@ -438,4 +417,4 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_d
     }

     tokens.clear();
-}
+}
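
A note on the new sampling path in the first hunk: instead of calling common_sampler_sample(), the code fills the sampler's candidate array directly from the raw logits, clears the sorted flag (the freshly filled array is in vocab order, not score order), applies the sampler chain by hand, and takes cur_p->data[0].id as the draft. Below is a minimal standalone sketch of that fill-then-pick pattern, with a plain struct and a greedy argmax standing in for the llama.cpp types and sampler chain; every name in it is illustrative, not the real API.

#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for llama_token_data: one candidate per vocabulary id.
struct token_data {
    int   id;    // vocabulary id
    float logit; // raw model score for this id
};

int main() {
    // Stand-in for llama_get_logits_ith(ctx, last_tok_idx): one logit per vocab id.
    const std::vector<float> logits = { 0.1f, 2.3f, -1.0f, 0.7f };

    // Fill one candidate per id, mirroring the loop in the patch.
    std::vector<token_data> cur(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        cur[i] = { (int) i, logits[i] };
    }

    // Trivial "sampler chain": greedy argmax over the candidates.
    const auto best = std::max_element(cur.begin(), cur.end(),
        [](const token_data & a, const token_data & b) { return a.logit < b.logit; });

    printf("drafted token id: %d (logit %.2f)\n", best->id, best->logit); // -> id 1
    return 0;
}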