 #include <string>
 #include <vector>
 
-#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
 struct seq_draft {
@@ -188,6 +188,8 @@ int main(int argc, char ** argv) {
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
+    params.sparams.top_k = std::max(10, params.sparams.top_k);
+
     for (int s = 0; s < n_seq_dft; ++s) {
         // allocate llama_sampler for each draft sequence
         drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
@@ -346,6 +348,7 @@ int main(int argc, char ** argv) {
                         std::vector<float> probs(dist_tgt.size);
                         for (size_t i = 0; i < dist_tgt.size; ++i) {
                             probs[i] = dist_tgt.data[i].p;
+                            LOG_DBG(" - %d: %f\n", dist_tgt.data[i].id, dist_tgt.data[i].p);
                         }
 
                         std::discrete_distribution<> dist(probs.begin(), probs.end());
@@ -449,10 +452,13 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        if (drafts[0].smpl) {
-            common_sampler_free(drafts[0].smpl);
-        }
-        drafts[0].smpl = common_sampler_clone(smpl);
+        // TODO: this needs a better fix - we want the draft samplers to have different parameters
+        //       from the target sampler, so we should not copy the target sampler
+        //if (drafts[0].smpl) {
+        //    common_sampler_free(drafts[0].smpl);
+        //}
+        //drafts[0].smpl = common_sampler_clone(smpl);
+        common_sampler_reset(drafts[0].smpl);
 
         int n_seq_cur  = 1;
         int n_past_cur = n_past_dft;
@@ -540,6 +546,12 @@ int main(int argc, char ** argv) {
 
                     const int s = sa[is];
 
+                    // only collect very high-confidence draft tokens
+                    if (cur_p->data[is].p < 0.90) {
+                        drafts[s].drafting = false;
+                        continue;
+                    }
+
                    common_sampler_accept(drafts[s].smpl, id, true);
 
                     drafts[s].tokens.push_back(id);
@@ -577,6 +589,12 @@ int main(int argc, char ** argv) {
             }
         }
 
+        // don't waste time on small batches
+        if (batch_tgt.n_tokens < 5) {
+            batch_tgt.n_tokens = 1;
+            drafts[0].tokens.resize(batch_tgt.n_tokens);
+        }
+
         // evaluate the target model on the drafted tokens
         {
             llama_kv_cache_seq_keep(ctx_tgt, 0);
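
For readers unfamiliar with the sampling step these hunks touch, here is a minimal standalone sketch of the two ideas involved: drawing a candidate index in proportion to the target-model probabilities with std::discrete_distribution, and a hard confidence cutoff like the 0.90 threshold the patch applies before accepting a draft token. The probability values and the main() wrapper are made up for illustration; this is plain C++, not llama.cpp code.

    #include <cstdio>
    #include <random>
    #include <vector>

    int main() {
        // made-up target-model probabilities for four candidate tokens
        std::vector<float> probs = { 0.05f, 0.92f, 0.02f, 0.01f };

        // draw a candidate index in proportion to probs, as the patch does with dist_tgt
        std::mt19937 rng(42);
        std::discrete_distribution<> dist(probs.begin(), probs.end());
        const int sampled = dist(rng);
        std::printf("sampled candidate index: %d\n", sampled);

        // the draft-side cutoff: keep drafting only while the chosen candidate is very
        // confident, mirroring the hard-coded 0.90 threshold in the hunk above
        const bool keep_drafting = probs[sampled] >= 0.90f;
        std::printf("keep drafting: %s\n", keep_drafting ? "yes" : "no");

        return 0;
    }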
 | 