@@ -13,6 +13,9 @@
 int main(int argc, char ** argv) {
     common_params params;
 
+    // minimum size of the draft to use
+    const int n_min = 5;
+
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }
@@ -92,31 +95,29 @@ int main(int argc, char ** argv) {
     // everything until here is standard initialization
     // the relevant stuff for speculative decoding starts here
 
-    const int n_input = inp.size();
-
     const auto t_enc_start = ggml_time_us();
 
     // target model sampling context
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
 
     // eval the prompt
-    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
+    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
 
     // note: keep the last token separate!
     llama_token id_last = inp.back();
 
-    auto prompt_dft = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    // all tokens currently in the target context
+    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
 
     int n_past = inp.size() - 1;
 
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
-    params_spec.n_min   = 5;
     params_spec.n_reuse = 256;
     params_spec.p_min   = 0.9f;
 
-    struct common_speculative * spec = common_speculative_init(params_spec, ctx_dft);
+    struct common_speculative * spec = common_speculative_init(ctx_dft);
 
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
 
@@ -125,21 +126,30 @@ int main(int argc, char ** argv) {
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
-        // always have a token to evaluate from before
-        common_batch_clear(batch_tgt);
-        common_batch_add(batch_tgt, id_last, n_past, { 0 }, true);
-
-        // optionally, append draft tokens to the target batch
+        // optionally, generate draft tokens that can be appended to the target batch
         //
         // this is the most important part of the speculation. the more probable tokens that are provided here
         // the better the performance will be. in theory, this computation can be performed asynchronously and even
         // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
         // from a cache or lookup tables.
         //
-        common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1);
+        llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
+
+        // always have a token to evaluate from before - id_last
+        common_batch_clear(batch_tgt);
+        common_batch_add(batch_tgt, id_last, n_past++, { 0 }, true);
 
         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
         {
+            // do not waste time on small drafts
+            if (draft.size() < n_min) {
+                draft.clear();
+            }
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+            }
+
             // LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
 
             llama_decode(ctx_tgt, batch_tgt);
@@ -152,11 +162,11 @@ int main(int argc, char ** argv) {
         // available logits from the batch and sample the next token until we run out of logits or the sampler
         // disagrees with the draft
         //
-        const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);
+        const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft);
 
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
-        n_past += ids.size();
+        n_past += ids.size() - 1;
         n_drafted += batch_tgt.n_tokens - 1;
         n_accept += ids.size() - 1;
 
@@ -192,16 +202,16 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        LOG_DBG("accepted %d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, id, token_str.c_str());
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
 
         {
             LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
 
             llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
         }
 
-        prompt_dft.push_back(id_last);
-        prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1);
+        prompt_tgt.push_back(id_last);
+        prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
 
         // remember the last accepted token for the next iteration
         id_last = id;
@@ -210,6 +220,8 @@ int main(int argc, char ** argv) {
 
     auto t_dec_end = ggml_time_us();
 
+    const int n_input = inp.size();
+
     LOG("\n\n");
 
     LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
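
For orientation, the per-iteration flow after this change boils down to the rough sketch below (not part of the change itself). It uses only calls that appear in the diff above and assumes `ctx_tgt`, `ctx_dft`, `smpl`, `batch_tgt`, `prompt_tgt`, `id_last` and `n_past` are set up as in the example; the `n_draft` value is a placeholder.

```cpp
// sketch of one draft/verify iteration with the refactored helpers;
// names and signatures are taken from the diff above, values are examples only
struct common_speculative_params params_spec;
params_spec.n_draft = 16;   // placeholder: max tokens to draft per step
params_spec.n_reuse = 256;
params_spec.p_min   = 0.9f; // keep only fairly confident draft tokens

struct common_speculative * spec = common_speculative_init(ctx_dft);

// 1. ask the speculator for a probable continuation of (prompt_tgt + id_last)
llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

// 2. build one target batch: the last accepted token followed by the draft tokens
common_batch_clear(batch_tgt);
common_batch_add(batch_tgt, id_last, n_past++, { 0 }, true);
for (size_t i = 0; i < draft.size(); ++i) {
    common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
}

// 3. verify all positions with a single target decode and collect the accepted tokens
llama_decode(ctx_tgt, batch_tgt);
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft);
```

Keeping draft generation behind the single `common_speculative_gen_draft()` call is what makes the draft source swappable: as the comment in the diff notes, the draft could be produced asynchronously, on a remote device, or come from a cache or lookup table instead of a draft LLM.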