Skip to content

Commit 23d7407

Browse files
authored
Merge pull request #15 from ggml-org/xsn/private_batch_api
speculative : adapt to new llama API
2 parents dc4bb64 + 7a3c178 commit 23d7407

File tree

1 file changed

+36
-27
lines changed

1 file changed

+36
-27
lines changed

examples/speculative/speculative.cpp

Lines changed: 36 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@ int main(int argc, char ** argv) {
4545
}
4646

4747
common_init();
48-
#if 0
4948
if (params.speculative.model.empty()) {
5049
LOG_ERR("%s: --model-draft is required\n", __func__);
5150
return 1;
@@ -169,9 +168,9 @@ int main(int argc, char ** argv) {
169168
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
170169
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
171170
llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
172-
llama_decode_ext(ctx_tgt, batch0);
173-
llama_decode_ext(ctx_tgt, batch1);
174-
llama_decode_ext(ctx_dft, batch2);
171+
llama_decode_ext(ctx_tgt, batch0.get());
172+
llama_decode_ext(ctx_tgt, batch1.get());
173+
llama_decode_ext(ctx_dft, batch2.get());
175174

176175
const auto t_enc_end = ggml_time_us();
177176

@@ -338,7 +337,7 @@ int main(int argc, char ** argv) {
338337
if (i == s) {
339338
continue;
340339
}
341-
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
340+
if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
342341
// synchronize active status for sequences with the same drafted token
343342
drafts[i].active = drafts[i].active && accept;
344343
if (!drafts[i].active) {
@@ -446,7 +445,7 @@ int main(int argc, char ** argv) {
446445

447446
llama_batch_ext_clear(batch_dft);
448447
llama_seq_id seq_id = 0;
449-
llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true);
448+
llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true);
450449

451450
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
452451
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
@@ -475,13 +474,19 @@ int main(int argc, char ** argv) {
475474
drafts[0].drafting = true;
476475
drafts[0].i_batch_dft = 0;
477476

478-
llama_batch_ext_clear(batch_tgt);
479-
llama_seq_id seq_id = 0;
480-
llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
477+
struct batch_info {
478+
llama_token id;
479+
llama_pos pos;
480+
std::vector<llama_seq_id> seq_id;
481+
};
482+
483+
std::vector<batch_info> batch_tgt_data;
484+
485+
batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} });
481486

482487
// sample n_draft tokens from the draft model using tree-based sampling
483488
for (int i = 0; i < n_draft; ++i) {
484-
batch_dft.n_tokens = 0;
489+
llama_batch_ext_clear(batch_dft);
485490

486491
for (int s = 0; s < n_seq_dft; ++s) {
487492
drafts[s].skip = false;
@@ -512,11 +517,10 @@ int main(int argc, char ** argv) {
512517
llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
513518

514519
// all previous tokens from this branch are now also part of the new branch
515-
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
516-
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
517-
if (batch_tgt.seq_id[t][p] == s) {
518-
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
519-
batch_tgt.n_seq_id[t]++;
520+
for (int t = 0; t < (int) batch_tgt_data.size(); ++t) {
521+
for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) {
522+
if (batch_tgt_data[t].seq_id[p] == s) {
523+
batch_tgt_data[t].seq_id.push_back(n_seq_cur);
520524
break;
521525
}
522526
}
@@ -558,32 +562,30 @@ int main(int argc, char ** argv) {
558562
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
559563

560564
// add unique drafted tokens to the target batch
561-
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
565+
drafts[s].i_batch_tgt.push_back(batch_tgt_data.size());
562566

563-
common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
567+
batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }});
564568

565569
// add the token to the batch for batched decoding with the draft model
566-
drafts[s].i_batch_dft = batch_dft.n_tokens;
567-
568-
common_batch_add(batch_dft, id, n_past_cur, { s }, true);
570+
drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true);
569571

570-
if (batch_tgt.n_tokens > n_draft) {
572+
if (batch_tgt_data.size() > (size_t) n_draft) {
571573
drafts[s].drafting = false;
572574
}
573575
}
574576
}
575577

576578
// no sequence is drafting anymore
577-
if (batch_dft.n_tokens == 0) {
579+
if (llama_batch_ext_get_n_tokens(batch_dft) == 0) {
578580
break;
579581
}
580582

581583
// evaluate the drafted tokens on the draft model
582-
llama_decode(ctx_dft, batch_dft);
584+
llama_decode_ext(ctx_dft, batch_dft);
583585
++n_past_cur;
584586
++n_drafted;
585587

586-
if (batch_tgt.n_tokens > n_draft) {
588+
if (batch_tgt_data.size() > (size_t) n_draft) {
587589
break;
588590
}
589591
}
@@ -595,8 +597,15 @@ int main(int argc, char ** argv) {
595597
llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
596598
}
597599

600+
llama_batch_ext_clear(batch_tgt);
601+
for (int i = 0; i < (int) batch_tgt_data.size(); ++i) {
602+
const auto & data = batch_tgt_data[i];
603+
604+
llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true);
605+
}
606+
598607
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
599-
llama_decode(ctx_tgt, batch_tgt);
608+
llama_decode_ext(ctx_tgt, batch_tgt);
600609
++n_past_tgt;
601610
}
602611

@@ -639,12 +648,12 @@ int main(int argc, char ** argv) {
639648
common_sampler_free(drafts[s].smpl);
640649
}
641650

642-
llama_batch_free(batch_dft);
651+
llama_batch_ext_free(batch_dft);
652+
llama_batch_ext_free(batch_tgt);
643653

644654
llama_backend_free();
645655

646656
LOG("\n\n");
647657

648-
#endif
649658
return 0;
650659
}

0 commit comments

Comments (0)