Skip to content

Commit 67179f9

Browse files
committed
speculative : adapt to new llama API
1 parent dc4bb64 commit 67179f9

File tree

1 file changed: +35 additions, −27 deletions

examples/speculative/speculative.cpp

Lines changed: 35 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@ int main(int argc, char ** argv) {
4545
}
4646

4747
common_init();
48-
#if 0
4948
if (params.speculative.model.empty()) {
5049
LOG_ERR("%s: --model-draft is required\n", __func__);
5150
return 1;
@@ -169,9 +168,9 @@ int main(int argc, char ** argv) {
169168
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
170169
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
171170
llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
172-
llama_decode_ext(ctx_tgt, batch0);
173-
llama_decode_ext(ctx_tgt, batch1);
174-
llama_decode_ext(ctx_dft, batch2);
171+
llama_decode_ext(ctx_tgt, batch0.get());
172+
llama_decode_ext(ctx_tgt, batch1.get());
173+
llama_decode_ext(ctx_dft, batch2.get());
175174

176175
const auto t_enc_end = ggml_time_us();
177176

@@ -338,7 +337,7 @@ int main(int argc, char ** argv) {
338337
if (i == s) {
339338
continue;
340339
}
341-
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
340+
if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
342341
// synchronize active status for sequences with the same drafted token
343342
drafts[i].active = drafts[i].active && accept;
344343
if (!drafts[i].active) {
@@ -446,7 +445,7 @@ int main(int argc, char ** argv) {
446445

447446
llama_batch_ext_clear(batch_dft);
448447
llama_seq_id seq_id = 0;
449-
llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true);
448+
llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true);
450449

451450
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
452451
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
@@ -475,13 +474,18 @@ int main(int argc, char ** argv) {
475474
drafts[0].drafting = true;
476475
drafts[0].i_batch_dft = 0;
477476

478-
llama_batch_ext_clear(batch_tgt);
479-
llama_seq_id seq_id = 0;
480-
llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
477+
struct batch_info {
478+
llama_token id;
479+
std::vector<llama_seq_id> seq_id;
480+
};
481+
482+
std::vector<batch_info> batch_tgt_infos;
483+
484+
batch_tgt_infos.push_back({ drafts[0].tokens[0], {0} });
481485

482486
// sample n_draft tokens from the draft model using tree-based sampling
483487
for (int i = 0; i < n_draft; ++i) {
484-
batch_dft.n_tokens = 0;
488+
llama_batch_ext_clear(batch_dft);
485489

486490
for (int s = 0; s < n_seq_dft; ++s) {
487491
drafts[s].skip = false;
@@ -512,11 +516,10 @@ int main(int argc, char ** argv) {
512516
llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
513517

514518
// all previous tokens from this branch are now also part of the new branch
515-
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
516-
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
517-
if (batch_tgt.seq_id[t][p] == s) {
518-
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
519-
batch_tgt.n_seq_id[t]++;
519+
for (int t = 0; t < (int) batch_tgt_infos.size(); ++t) {
520+
for (int p = 0; p < (int) batch_tgt_infos[t].seq_id.size(); ++p) {
521+
if (batch_tgt_infos[t].seq_id[p] == s) {
522+
batch_tgt_infos[t].seq_id.push_back(n_seq_cur);
520523
break;
521524
}
522525
}
@@ -558,32 +561,30 @@ int main(int argc, char ** argv) {
558561
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
559562

560563
// add unique drafted tokens to the target batch
561-
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
564+
drafts[s].i_batch_tgt.push_back(batch_tgt_infos.size());
562565

563-
common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
566+
batch_tgt_infos.push_back({ id, { s }});
564567

565568
// add the token to the batch for batched decoding with the draft model
566-
drafts[s].i_batch_dft = batch_dft.n_tokens;
567-
568-
common_batch_add(batch_dft, id, n_past_cur, { s }, true);
569+
drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true);
569570

570-
if (batch_tgt.n_tokens > n_draft) {
571+
if (batch_tgt_infos.size() > (size_t) n_draft) {
571572
drafts[s].drafting = false;
572573
}
573574
}
574575
}
575576

576577
// no sequence is drafting anymore
577-
if (batch_dft.n_tokens == 0) {
578+
if (llama_batch_ext_get_n_tokens(batch_dft) == 0) {
578579
break;
579580
}
580581

581582
// evaluate the drafted tokens on the draft model
582-
llama_decode(ctx_dft, batch_dft);
583+
llama_decode_ext(ctx_dft, batch_dft);
583584
++n_past_cur;
584585
++n_drafted;
585586

586-
if (batch_tgt.n_tokens > n_draft) {
587+
if (batch_tgt_infos.size() > (size_t) n_draft) {
587588
break;
588589
}
589590
}
@@ -595,8 +596,15 @@ int main(int argc, char ** argv) {
595596
llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
596597
}
597598

599+
llama_batch_ext_clear(batch_tgt);
600+
for (int i = 0; i < (int) batch_tgt_infos.size(); ++i) {
601+
const auto & info = batch_tgt_infos[i];
602+
603+
llama_batch_ext_add_text(batch_tgt, info.id, n_past_tgt + i, info.seq_id.data(), info.seq_id.size(), true);
604+
}
605+
598606
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
599-
llama_decode(ctx_tgt, batch_tgt);
607+
llama_decode_ext(ctx_tgt, batch_tgt);
600608
++n_past_tgt;
601609
}
602610

@@ -639,12 +647,12 @@ int main(int argc, char ** argv) {
639647
common_sampler_free(drafts[s].smpl);
640648
}
641649

642-
llama_batch_free(batch_dft);
650+
llama_batch_ext_free(batch_dft);
651+
llama_batch_ext_free(batch_tgt);
643652

644653
llama_backend_free();
645654

646655
LOG("\n\n");
647656

648-
#endif
649657
return 0;
650658
}

0 commit comments

Comments (0)