Skip to content

Commit 5e6dad9

Browse files
committed
speculative : experimenting with Qwen2.5
1 parent 33bdee6 commit 5e6dad9

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

examples/speculative/speculative.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <string>
1313
#include <vector>
1414

15-
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
15+
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
1616
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
1717

1818
struct seq_draft {
@@ -188,6 +188,8 @@ int main(int argc, char ** argv) {
188188
// draft sequence data
189189
std::vector<seq_draft> drafts(n_seq_dft);
190190

191+
params.sparams.top_k = std::max(10, params.sparams.top_k);
192+
191193
for (int s = 0; s < n_seq_dft; ++s) {
192194
// allocate llama_sampler for each draft sequence
193195
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
@@ -346,6 +348,7 @@ int main(int argc, char ** argv) {
346348
std::vector<float> probs(dist_tgt.size);
347349
for (size_t i = 0; i < dist_tgt.size; ++i) {
348350
probs[i] = dist_tgt.data[i].p;
351+
LOG_DBG(" - %d: %f\n", dist_tgt.data[i].id, dist_tgt.data[i].p);
349352
}
350353

351354
std::discrete_distribution<> dist(probs.begin(), probs.end());
@@ -449,10 +452,13 @@ int main(int argc, char ** argv) {
449452
break;
450453
}
451454

452-
if (drafts[0].smpl) {
453-
common_sampler_free(drafts[0].smpl);
454-
}
455-
drafts[0].smpl = common_sampler_clone(smpl);
455+
// TODO: this needs better fix - we want the draft samplers to have different parameters from the target sampler
456+
// so we should not copy the target sampler
457+
//if (drafts[0].smpl) {
458+
// common_sampler_free(drafts[0].smpl);
459+
//}
460+
//drafts[0].smpl = common_sampler_clone(smpl);
461+
common_sampler_reset(drafts[0].smpl);
456462

457463
int n_seq_cur = 1;
458464
int n_past_cur = n_past_dft;
@@ -540,6 +546,12 @@ int main(int argc, char ** argv) {
540546

541547
const int s = sa[is];
542548

549+
// only collect very high-confidence draft tokens
550+
if (cur_p->data[is].p < 0.90) {
551+
drafts[s].drafting = false;
552+
continue;
553+
}
554+
543555
common_sampler_accept(drafts[s].smpl, id, true);
544556

545557
drafts[s].tokens.push_back(id);
@@ -577,6 +589,12 @@ int main(int argc, char ** argv) {
577589
}
578590
}
579591

592+
// don't waste time on small batches
593+
if (batch_tgt.n_tokens < 5) {
594+
batch_tgt.n_tokens = 1;
595+
drafts[0].tokens.resize(batch_tgt.n_tokens);
596+
}
597+
580598
// evaluate the target model on the drafted tokens
581599
{
582600
llama_kv_cache_seq_keep(ctx_tgt, 0);

0 commit comments

Comments
 (0)