@@ -27,7 +27,7 @@ struct common_speculative * common_speculative_init(
2727 };
2828
2929 // TODO: optimize or pass from outside?
30- #if 1
30+ #if 0
3131 {
3232 common_sampler_params sparams;
3333 sparams.no_perf = false;
@@ -156,13 +156,27 @@ llama_tokens common_speculative_gen_draft(
156156 }
157157 }
158158
159- LOG_DBG (" %s: reuse_i = %d, reuse_n = %d\n " , __func__, reuse_i, reuse_n);
159+ LOG_DBG (" %s: reuse_i = %d, reuse_n = %d, prompt = %d\n " , __func__, reuse_i, reuse_n, (int ) prompt.size ());
160+
161+ llama_tokens result;
162+ result.reserve (params.n_draft );
160163
161164 if (reuse_n == 0 ) {
162165 llama_kv_cache_clear (ctx);
163166
164167 prompt.clear ();
165168 } else {
169+ if (reuse_i + reuse_n < (int ) prompt.size () && prompt[reuse_i + reuse_n] == id_last) {
170+ for (int i = reuse_i + reuse_n + 1 ; i < (int ) prompt.size (); ++i) {
171+ result.push_back (prompt[i]);
172+
173+ if (result.size () >= params.n_draft ) {
174+ break ;
175+ }
176+ }
177+ return result;
178+ }
179+
166180 llama_kv_cache_seq_rm (ctx, 0 , 0 , reuse_i);
167181 llama_kv_cache_seq_rm (ctx, 0 , reuse_i + reuse_n, -1 );
168182 llama_kv_cache_seq_add (ctx, 0 , reuse_i, -1 , -reuse_i);
@@ -201,9 +215,6 @@ llama_tokens common_speculative_gen_draft(
201215
202216 common_sampler_reset (smpl);
203217
204- llama_tokens result;
205- result.reserve (params.n_draft );
206-
207218 // sample n_draft tokens from the draft model
208219 for (int i = 0 ; i < params.n_draft ; ++i) {
209220 common_batch_clear (batch);
0 commit comments