@@ -45,7 +45,6 @@ int main(int argc, char ** argv) {
4545 }
4646
4747 common_init ();
48- #if 0
4948 if (params.speculative .model .empty ()) {
5049 LOG_ERR (" %s: --model-draft is required\n " , __func__);
5150 return 1 ;
@@ -169,9 +168,9 @@ int main(int argc, char ** argv) {
169168 llama_batch_ext_ptr batch0 (llama_batch_ext_init_from_text ( inp.data (), n_input - 1 , 0 , 0 , true ));
170169 llama_batch_ext_ptr batch1 (llama_batch_ext_init_from_text (&inp.back (), 1 , n_input - 1 , 0 , true ));
171170 llama_batch_ext_ptr batch2 (llama_batch_ext_init_from_text ( inp.data (), n_input , 0 , 0 , true ));
172- llama_decode_ext(ctx_tgt, batch0);
173- llama_decode_ext(ctx_tgt, batch1);
174- llama_decode_ext(ctx_dft, batch2);
171+ llama_decode_ext (ctx_tgt, batch0. get () );
172+ llama_decode_ext (ctx_tgt, batch1. get () );
173+ llama_decode_ext (ctx_dft, batch2. get () );
175174
176175 const auto t_enc_end = ggml_time_us ();
177176
@@ -338,7 +337,7 @@ int main(int argc, char ** argv) {
338337 if (i == s) {
339338 continue ;
340339 }
341- if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
340+ if (drafts[i].active && drafts[i]. tokens [i_dft] == drafts[s].tokens [i_dft]) {
342341 // synchronize active status for sequences with the same drafted token
343342 drafts[i].active = drafts[i].active && accept;
344343 if (!drafts[i].active ) {
@@ -446,7 +445,7 @@ int main(int argc, char ** argv) {
446445
447446 llama_batch_ext_clear (batch_dft);
448447 llama_seq_id seq_id = 0 ;
449- llama_batch_ext_add_text(batch_tgt , token_id, n_past_tgt , &seq_id, 1, true);
448+ llama_batch_ext_add_text (batch_dft , token_id, n_past_dft , &seq_id, 1 , true );
450449
451450 llama_kv_self_seq_rm (ctx_dft, 0 , n_past_dft, -1 );
452451 // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
@@ -475,13 +474,19 @@ int main(int argc, char ** argv) {
475474 drafts[0 ].drafting = true ;
476475 drafts[0 ].i_batch_dft = 0 ;
477476
478- llama_batch_ext_clear(batch_tgt);
479- llama_seq_id seq_id = 0;
480- llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
// batch_info: staging record for one token that will be submitted to the target batch.
// Records are accumulated in `batch_tgt_data` while drafting and are later passed
// field-by-field to llama_batch_ext_add_text() when `batch_tgt` is rebuilt.
// (The leading line-number/`+` prefixes below belong to the diff rendering and are left untouched.)
477+ struct batch_info {
// token passed as the token argument of llama_batch_ext_add_text()
478+ llama_token id;
// position for the token; visible call sites store n_past_tgt or n_past_tgt + i + 1 here
479+ llama_pos pos;
// sequence ids the token belongs to; starts with a single id and gains another
// (via push_back) when a draft branch containing this token is split into a new sequence
480+ std::vector<llama_seq_id> seq_id;
481+ };
482+
483+ std::vector<batch_info> batch_tgt_data;
484+
485+ batch_tgt_data.push_back ({ drafts[0 ].tokens [0 ], n_past_tgt, {0 } });
481486
482487 // sample n_draft tokens from the draft model using tree-based sampling
483488 for (int i = 0 ; i < n_draft; ++i) {
484- batch_dft.n_tokens = 0 ;
489+ llama_batch_ext_clear ( batch_dft) ;
485490
486491 for (int s = 0 ; s < n_seq_dft; ++s) {
487492 drafts[s].skip = false ;
@@ -512,11 +517,10 @@ int main(int argc, char ** argv) {
512517 llama_kv_self_seq_cp (ctx_dft, s, n_seq_cur, -1 , -1 );
513518
514519 // all previous tokens from this branch are now also part of the new branch
515- for (int t = 0; t < batch_tgt.n_tokens; ++t) {
516- for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
517- if (batch_tgt.seq_id[t][p] == s) {
518- batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
519- batch_tgt.n_seq_id[t]++;
520+ for (int t = 0 ; t < (int ) batch_tgt_data.size (); ++t) {
521+ for (int p = 0 ; p < (int ) batch_tgt_data[t].seq_id .size (); ++p) {
522+ if (batch_tgt_data[t].seq_id [p] == s) {
523+ batch_tgt_data[t].seq_id .push_back (n_seq_cur);
520524 break ;
521525 }
522526 }
@@ -558,32 +562,30 @@ int main(int argc, char ** argv) {
558562 drafts[s].dists .push_back ({cur_p->data , cur_p->data + cur_p->size });
559563
560564 // add unique drafted tokens to the target batch
561- drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens );
565+ drafts[s].i_batch_tgt .push_back (batch_tgt_data. size () );
562566
563- common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true );
567+ batch_tgt_data. push_back ({ id, n_past_tgt + i + 1 , { s }} );
564568
565569 // add the token to the batch for batched decoding with the draft model
566- drafts[s].i_batch_dft = batch_dft.n_tokens;
567-
568- common_batch_add(batch_dft, id, n_past_cur, { s }, true);
570+ drafts[s].i_batch_dft = llama_batch_ext_add_text (batch_dft, id, n_past_cur, &s, 1 , true );
569571
570- if (batch_tgt.n_tokens > n_draft) {
572+ if (batch_tgt_data. size () > ( size_t ) n_draft) {
571573 drafts[s].drafting = false ;
572574 }
573575 }
574576 }
575577
576578 // no sequence is drafting anymore
577- if (batch_dft.n_tokens == 0) {
579+ if (llama_batch_ext_get_n_tokens ( batch_dft) == 0 ) {
578580 break ;
579581 }
580582
581583 // evaluate the drafted tokens on the draft model
582- llama_decode (ctx_dft, batch_dft);
584+ llama_decode_ext (ctx_dft, batch_dft);
583585 ++n_past_cur;
584586 ++n_drafted;
585587
586- if (batch_tgt.n_tokens > n_draft) {
588+ if (batch_tgt_data. size () > ( size_t ) n_draft) {
587589 break ;
588590 }
589591 }
@@ -595,8 +597,15 @@ int main(int argc, char ** argv) {
595597 llama_kv_self_seq_cp (ctx_tgt, 0 , s, -1 , -1 );
596598 }
597599
600+ llama_batch_ext_clear (batch_tgt);
601+ for (int i = 0 ; i < (int ) batch_tgt_data.size (); ++i) {
602+ const auto & data = batch_tgt_data[i];
603+
604+ llama_batch_ext_add_text (batch_tgt, data.id , data.pos , data.seq_id .data (), data.seq_id .size (), true );
605+ }
606+
598607 // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
599- llama_decode (ctx_tgt, batch_tgt);
608+ llama_decode_ext (ctx_tgt, batch_tgt);
600609 ++n_past_tgt;
601610 }
602611
@@ -639,12 +648,12 @@ int main(int argc, char ** argv) {
639648 common_sampler_free (drafts[s].smpl );
640649 }
641650
642- llama_batch_free(batch_dft);
651+ llama_batch_ext_free (batch_dft);
652+ llama_batch_ext_free (batch_tgt);
643653
644654 llama_backend_free ();
645655
646656 LOG (" \n\n " );
647657
648- #endif
649658 return 0 ;
650659}
0 commit comments