@@ -48,11 +48,11 @@ int main(int argc, char ** argv) {
4848 auto tokens = common_tokenize (ctx, params.prompt , true );
4949
5050 // prepare the batch
51- llama_batch_ext * batch = llama_batch_ext_init_from_text (tokens.data (), tokens.size (), 0 , 0 , true );
51+ auto batch = llama_batch_ext_ptr::init_from_text (tokens.data (), tokens.size (), 0 , 0 , true );
5252
5353 // evaluate prompt
54- llama_decode_ext (ctx, batch);
55- n_past += llama_batch_ext_get_n_tokens (batch);
54+ llama_decode_ext (ctx, batch. get () );
55+ n_past += llama_batch_ext_get_n_tokens (batch. get () );
5656
5757 // save state (rng, logits, embedding and kv_cache) to file
5858 {
@@ -79,13 +79,13 @@ int main(int argc, char ** argv) {
7979 printf (" %s" , next_token_str.c_str ());
8080 result0 += next_token_str;
8181
82- llama_batch_ext_clear (batch);
82+ llama_batch_ext_clear (batch. get () );
8383 llama_seq_id seq_id = 0 ;
84- llama_batch_ext_add_text (batch, next_token, 0 , &seq_id, 1 , true );
84+ llama_batch_ext_add_text (batch. get () , next_token, 0 , &seq_id, 1 , true );
8585
86- if (llama_decode_ext (ctx, batch)) {
86+ if (llama_decode_ext (ctx, batch. get () )) {
8787 fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
88- llama_batch_ext_free (batch);
88+ llama_batch_ext_free (batch. get () );
8989 return 1 ;
9090 }
9191 n_past += 1 ;
@@ -132,13 +132,13 @@ int main(int argc, char ** argv) {
132132 printf (" %s" , next_token_str.c_str ());
133133 result1 += next_token_str;
134134
135- llama_batch_ext_clear (batch);
135+ llama_batch_ext_clear (batch. get () );
136136 llama_seq_id seq_id = 0 ;
137- llama_batch_ext_add_text (batch, next_token, 0 , &seq_id, 1 , true );
137+ llama_batch_ext_add_text (batch. get () , next_token, 0 , &seq_id, 1 , true );
138138
139- if (llama_decode_ext (ctx2, batch)) {
139+ if (llama_decode_ext (ctx2, batch. get () )) {
140140 fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
141- llama_batch_ext_free (batch);
141+ llama_batch_ext_free (batch. get () );
142142 return 1 ;
143143 }
144144 n_past += 1 ;
@@ -214,13 +214,13 @@ int main(int argc, char ** argv) {
214214 printf (" %s" , next_token_str.c_str ());
215215 result2 += next_token_str;
216216
217- llama_batch_ext_clear (batch);
217+ llama_batch_ext_clear (batch. get () );
218218 llama_seq_id seq_id = 1 ; // seq 1 instead of 0
219- llama_batch_ext_add_text (batch, next_token, 0 , &seq_id, 1 , true );
219+ llama_batch_ext_add_text (batch. get () , next_token, 0 , &seq_id, 1 , true );
220220
221- if (llama_decode_ext (ctx3, batch)) {
221+ if (llama_decode_ext (ctx3, batch. get () )) {
222222 fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
223- llama_batch_ext_free (batch);
223+ llama_batch_ext_free (batch. get () );
224224 return 1 ;
225225 }
226226 n_past += 1 ;
@@ -232,7 +232,7 @@ int main(int argc, char ** argv) {
232232 llama_sampler_free (smpl2);
233233 llama_sampler_free (smpl3);
234234
235- llama_batch_ext_free (batch);
235+ llama_batch_ext_free (batch. get () );
236236
237237 if (result0 != result2) {
238238 fprintf (stderr, " \n %s : error : the seq restore generation is different\n " , __func__);
0 commit comments