@@ -13,7 +13,7 @@ struct common_speculative {
     struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_batch batch;
+    llama_batch_ext_ptr batch;
     llama_tokens prompt;
 };
 
@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .batch  = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
         /* .prompt = */ {},
     };
 
@@ -68,8 +68,6 @@ void common_speculative_free(struct common_speculative * spec) {
 
     common_sampler_free(spec->smpl);
 
-    llama_batch_free(spec->batch);
-
     delete spec;
 }
 
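Note: the explicit llama_batch_free(spec->batch) call disappears because llama_batch_ext_ptr is an owning wrapper, so the extended batch is released automatically when the common_speculative object is deleted. A minimal sketch of such a wrapper, assuming the C API pairs llama_batch_ext_init with a matching llama_batch_ext_free (the real typedef lives in the llama.cpp headers on this branch and may differ):

#include <memory>

// Sketch only: assumes llama_batch_ext is declared in llama.h and that
// llama_batch_ext_free() exists as the counterpart of llama_batch_ext_init();
// the in-tree definition may differ.
struct llama_batch_ext_deleter {
    void operator()(llama_batch_ext * batch) const { llama_batch_ext_free(batch); }
};

using llama_batch_ext_ptr = std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>;
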
@@ -150,6 +148,8 @@ llama_tokens common_speculative_gen_draft(
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
+    const llama_seq_id seq_id = 0;
+
     // reuse as much as possible from the old draft context
     // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
@@ -205,40 +205,40 @@ llama_tokens common_speculative_gen_draft(
     }
 
     // prepare a batch to evaluate any new tokens in the prompt
-    common_batch_clear(batch);
+    llama_batch_ext_clear(batch.get());
 
     for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
-        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+        llama_batch_ext_add_text_token(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
 
     // we should rarely end-up here during normal decoding
-    if (batch.n_tokens > 0) {
+    if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
-        llama_decode(ctx, batch);
+        llama_decode_ext(ctx, batch.get());
     }
 
     const llama_pos n_past = prompt.size();
 
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
-    common_batch_clear(batch);
-    common_batch_add  (batch, id_last, n_past, { 0 }, true);
+    llama_batch_ext_clear(batch.get());
+    llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true);
 
     prompt.push_back(id_last);
 
     //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 
-    llama_decode(ctx, batch);
+    llama_decode_ext(ctx, batch.get());
 
     common_sampler_reset(smpl);
 
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         common_sampler_sample(smpl, ctx, 0, true);
 
@@ -265,10 +265,10 @@ llama_tokens common_speculative_gen_draft(
             break;
         }
 
-        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+        llama_batch_ext_add_text_token(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
 
         // evaluate the drafted tokens on the draft model
-        llama_decode(ctx, batch);
+        llama_decode_ext(ctx, batch.get());
 
         prompt.push_back(id);
     }
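
Taken together, the conversion is mechanical: common_batch_clear/common_batch_add on a plain llama_batch become llama_batch_ext_clear/llama_batch_ext_add_text_token on batch.get(), the sequence id is passed as a pointer plus count instead of an initializer list, and llama_decode becomes llama_decode_ext. A condensed sketch of one draft-side decode step in the new style, using only the calls that appear in this diff (signatures assumed to match the extended-batch API on this branch):

// Condensed sketch mirroring the calls above: ctx is the draft llama_context,
// batch is the llama_batch_ext_ptr member, id_last/n_past as in gen_draft.
const llama_seq_id seq_id = 0;

llama_batch_ext_clear(batch.get());
// token id, position, sequence ids (pointer + count), request logits for this token
llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true);

if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
    llama_decode_ext(ctx, batch.get());
}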