@@ -142,6 +142,8 @@ llama_tokens common_speculative_gen_draft(
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
         while (i_start + cur < (int) prompt_tgt.size() &&
@@ -166,6 +168,8 @@ llama_tokens common_speculative_gen_draft(
 
         prompt.clear();
     } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
         if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
             for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
                 result.push_back(prompt[i]);
@@ -174,42 +178,51 @@ llama_tokens common_speculative_gen_draft(
                     break;
                 }
             }
+
             return result;
         }
 
-        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
 
-        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
-        prompt.erase(prompt.begin() + reuse_n, prompt.end());
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
     }
 
+    // prepare a batch to evaluate any new tokens in the prompt
     common_batch_clear(batch);
 
-    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         // LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
         common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
 
-    const llama_pos n_past = prompt_tgt.size() - i_start;
-
-    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
-
+    // we should rarely end-up here during normal decoding
     if (batch.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());
+        // LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
         llama_decode(ctx, batch);
     }
 
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
     common_batch_clear(batch);
     common_batch_add(batch, id_last, n_past, { 0 }, true);
 
     prompt.push_back(id_last);
 
-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());
+    // LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 
     llama_decode(ctx, batch);
 
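For reference, the prefix-reuse bookkeeping that this patch reorganizes can be illustrated with a small standalone sketch. This is not the llama.cpp implementation: the helper names (find_reuse, apply_reuse), the token values, and the fixed minimum-match threshold are invented for illustration (the real function gates reuse on the draft parameters and the context size), and the real code additionally mirrors the trimming into the draft model's KV cache via llama_kv_cache_seq_rm / llama_kv_cache_seq_add, now guarded by the reuse_i > 0 and reuse_n < prompt.size() checks added above.

// standalone sketch, not the llama.cpp code; only the std::vector bookkeeping is shown
#include <cstdio>
#include <vector>

using tokens = std::vector<int>;

struct reuse_info {
    int i = 0; // offset into the old draft prompt where the match starts
    int n = 0; // number of matching tokens
};

// matching loop in the spirit of the first hunk: slide over the old draft prompt and
// remember the longest run that matches prompt_tgt starting at i_start
static reuse_info find_reuse(const tokens & prompt, const tokens & prompt_tgt, int i_start, int n_reuse_min) {
    reuse_info r;
    for (int i = 0; i < (int) prompt.size(); ++i) {
        int cur = 0;
        while (i_start + cur < (int) prompt_tgt.size() &&
               i       + cur < (int) prompt.size() &&
               prompt_tgt[i_start + cur] == prompt[i + cur]) {
            cur++;
        }
        if (cur >= n_reuse_min && cur > r.n) {
            r.i = i;
            r.n = cur;
        }
    }
    return r;
}

// trimming in the spirit of the third hunk: drop everything before the match (a shift by
// -reuse_i in the real KV cache) and everything after it, keeping only the reused span
static void apply_reuse(tokens & prompt, const reuse_info & r) {
    if (r.i > 0) {
        prompt.erase(prompt.begin(), prompt.begin() + r.i);
    }
    if (r.n < (int) prompt.size()) {
        prompt.erase(prompt.begin() + r.n, prompt.end());
    }
}

int main() {
    tokens prompt     = { 5, 6, 7, 8, 9 }; // old draft prompt (already in the draft KV cache)
    tokens prompt_tgt = { 6, 7, 8, 42 };   // window of tokens accepted by the target model

    const int i_start = 0; // the real code uses max(0, prompt_tgt.size() - n_ctx)

    const reuse_info r = find_reuse(prompt, prompt_tgt, i_start, /*n_reuse_min =*/ 2);
    apply_reuse(prompt, r);

    printf("reuse_i = %d, reuse_n = %d, kept =", r.i, r.n);
    for (int t : prompt) {
        printf(" %d", t);
    }
    printf("\n"); // expected: reuse_i = 1, reuse_n = 3, kept = 6 7 8
}

Splitting the trimming into two guarded steps, as the patch does, simply skips the KV-cache calls when there is nothing to drop in front of or behind the reused span.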