@@ -120,7 +120,6 @@ int main(int argc, char ** argv) {
120120 }
121121 }
122122
123-
124123 // Tokenize the prompt
125124 std::vector<llama_token> inp;
126125 inp = common_tokenize (ctx_tgt, params.prompt , true , true );
@@ -139,18 +138,6 @@ int main(int argc, char ** argv) {
139138 LOG (" %s" , common_token_to_piece (ctx_tgt, id).c_str ());
140139 }
141140
142- const int n_input = inp.size ();
143-
144- const auto t_enc_start = ggml_time_us ();
145-
146- // eval the prompt
147- llama_decode (ctx_tgt, llama_batch_get_one (inp.data (), n_input - 1 ));
148-
149- // note: keep the last token separate!
150- llama_token id_last = inp.back ();
151-
152- int n_past = inp.size () - 1 ;
153-
154141 // how many tokens to draft each time
155142 int n_draft = params.n_draft ;
156143
@@ -161,9 +148,25 @@ int main(int argc, char ** argv) {
161148 // used to determine end of generation
162149 bool has_eos = false ;
163150
151+ // ================================================
152+ // everything until here is standard initialization
153+ // the relevant stuff for speculative decoding starts here
154+
155+ const int n_input = inp.size ();
156+
157+ const auto t_enc_start = ggml_time_us ();
158+
164159 // target model sampling context
165160 struct common_sampler * smpl = common_sampler_init (model_tgt, params.sparams );
166161
162+ // eval the prompt
163+ llama_decode (ctx_tgt, llama_batch_get_one (inp.data (), n_input - 1 ));
164+
165+ // note: keep the last token separate!
166+ llama_token id_last = inp.back ();
167+
168+ int n_past = inp.size () - 1 ;
169+
167170 // init the speculator
168171 struct common_speculative_params params_spec;
169172 params_spec.n_draft = n_draft;
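The block added here is the hinge of the whole example: the prompt is evaluated without its final token, and that token is carried forward by hand. A minimal restatement of the invariant, using only calls already present in this diff:

    // the target KV cache holds positions [0, n_past) after this point;
    // id_last is deliberately NOT decoded yet - each loop iteration submits it
    // together with the draft, so its logits are always fresh for verification
    const int n_input = (int) inp.size();

    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));

    llama_token id_last = inp.back();      // the verification anchor
    int         n_past  = n_input - 1;     // tokens currently in the target KV cache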
@@ -174,6 +177,13 @@ int main(int argc, char ** argv) {
174177 struct common_speculative * spec = common_speculative_init (params_spec);
175178
176179 // feed the prompt to the speculator
180+ //
181+ // this has to be kept synchronized with the target context
182+ //
183+ // TODO: simplify this by moving the context management logic in the common_speculative instance
184+ // for example, common_speculative_add_draft could be passed the entire context (or part of it) and the
185+ // speculator would automatically compute any new tokens that are not yet present in its own context
186+ //
177187 common_speculative_set_prompt (spec, inp.data (), n_input - 1 );
178188
179189 llama_batch batch_tgt = llama_batch_init (llama_n_batch (ctx_tgt), 0 , 1 );
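The TODO above concerns ownership of the draft-side context: right now the caller has to mirror, by hand, exactly what the target has already seen. A hypothetical sketch of the suggested direction, where the speculator reconciles an incoming token list against what it has already processed (spec_state, spec_sync and cached are illustrative names, not part of this PR):

    // assumes <vector> and llama.h (for llama_token); purely illustrative
    struct spec_state {
        std::vector<llama_token> cached; // tokens already evaluated by the draft model
    };

    static void spec_sync(spec_state & st, const std::vector<llama_token> & ctx_tokens) {
        // find the longest common prefix between what the speculator has and what the caller sends
        size_t n_common = 0;
        while (n_common < st.cached.size() && n_common < ctx_tokens.size() &&
               st.cached[n_common] == ctx_tokens[n_common]) {
            n_common++;
        }

        // 1) drop any diverging tail from the draft KV cache (e.g. via llama_kv_cache_seq_rm)
        // 2) decode only the unseen suffix ctx_tokens[n_common..] on the draft model
        // 3) remember the new state
        st.cached.assign(ctx_tokens.begin(), ctx_tokens.end());
    }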
@@ -188,23 +198,40 @@ int main(int argc, char ** argv) {
188198 common_batch_add (batch_tgt, id_last, n_past, { 0 }, true );
189199
190200 // optionally, append draft tokens to the target batch
201+ //
202+ // this is the most important part of the speculation. the more likely the drafted tokens are to be
203+ // accepted by the target, the better the performance will be. in theory, this computation can be performed
204+ // asynchronously and even offloaded to a remote device.
205+ //
191206 common_speculative_add_draft (spec, batch_tgt, id_last, n_past);
192207
193- // evaluate the target model on the drafted tokens
208+ // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
194209 {
195210 // LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
196211
197212 llama_decode (ctx_tgt, batch_tgt);
198213 }
199214
200- // process the full target batch and return the accepted token based on the target sampler
215+ // sample from the full target batch and return the accepted tokens based on the target sampler
216+ //
217+ // for a draft token to be accepted, the target sampler has to produce that exact token at its position
218+ // while that holds, instead of decoding the sampled token as we normally would, we simply continue with the
219+ // logits already available in the batch and sample the next token, until we either run out of logits or the
220+ // sampler disagrees with the draft
221+ //
201222 const auto ids = common_speculative_sample (spec, smpl, ctx_tgt);
202223
224+ GGML_ASSERT (ids.size () > 0 ); // there will always be at least one accepted token
225+
203226 n_past += ids.size ();
204227 n_drafted += batch_tgt.n_tokens - 1 ;
205228 n_accept += ids.size () - 1 ;
206229
207230 // process the accepted tokens and update contexts
231+ //
232+ // this is the standard token processing that we normally do
233+ // in this case, we do it for a group of accepted tokens at once
234+ //
208235 {
209236 llama_token id;
210237 std::string token_str;
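The hunk above leans on the fact that the target decode produced logits for every position in batch_tgt. A conceptual sketch of the acceptance rule (not the actual common_speculative_sample body; `draft` stands for the drafted tokens in batch order and is assumed here):

    std::vector<llama_token> accepted;

    for (int i = 0; i < batch_tgt.n_tokens; ++i) {
        // index i holds the logits computed after [id_last, draft[0], ..., draft[i-1]]
        const llama_token id = common_sampler_sample(smpl, ctx_tgt, i);

        common_sampler_accept(smpl, id, true);
        accepted.push_back(id);

        // the final position has no drafted counterpart, and any mismatch ends the run
        if (i == batch_tgt.n_tokens - 1 || id != draft[i]) {
            break;
        }
    }
    // `accepted` always holds at least one token, which is what the GGML_ASSERT above relies on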
@@ -232,7 +259,7 @@ int main(int argc, char ** argv) {
232259 break ;
233260 }
234261
235- LOG_DBG (" the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens \n " , id, token_str.c_str ());
262+ LOG_DBG (" accepted %d draft tokens, the last target token is: (%d, '%s')\n " , (int) ids.size () - 1 , id, token_str.c_str ());
236263
237264 {
238265 LOG_DBG (" clear kv cache from any extra tokens, n_past = %d\n " , n_past);
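Between this hunk and the next, the rejected part of the speculation is discarded. A short sketch of the cleanup, assuming the matching ctx_tgt call sits in the elided lines (only the ctx_dft call is visible in the diff):

    // every KV cache entry at position >= n_past belongs to a drafted token that was
    // not accepted this iteration; p1 = -1 means "up to the end of the sequence"
    llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1); // trim the target context
    llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1); // keep the draft context in lockstep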
@@ -241,6 +268,7 @@ int main(int argc, char ** argv) {
241268 llama_kv_cache_seq_rm (ctx_dft, 0 , n_past, -1 );
242269 }
243270
271+ // remember the last accepted token for the next iteration

244272 id_last = id;
245273 }
246274 }
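Condensing the hunks above, the generation loop introduced by this diff has the following shape (logging, EOS/limit handling and the surrounding setup are omitted; common_batch_clear is assumed from the elided part of the loop):

    while (true) {
        // 1) the batch starts with the last accepted token, followed by the draft
        common_batch_clear(batch_tgt);
        common_batch_add  (batch_tgt, id_last, n_past, { 0 }, true);
        common_speculative_add_draft(spec, batch_tgt, id_last, n_past);

        // 2) a single target decode scores the whole draft in parallel
        llama_decode(ctx_tgt, batch_tgt);

        // 3) keep the longest prefix the target sampler agrees with (always >= 1 token)
        const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);

        n_past    += ids.size();
        n_drafted += batch_tgt.n_tokens - 1;
        n_accept  += ids.size() - 1;

        // 4) print the accepted tokens, stop on EOS or when the token budget is reached,
        //    trim the rejected positions from the KV caches, then carry the last token over
        id_last = ids.back();
    }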