 #include "clip.h"
 #include "stb_image.h"
 #include "llama.h"
+#include "llama-cpp.h"
 #include "ggml.h"
 #include "console.h"
 
@@ -63,7 +64,7 @@ struct gemma3_context {
     llama_model * model;
     llama_context * lctx;
     const llama_vocab * vocab;
-    llama_batch batch;
+    llama_batch_ext_ptr batch;
 
     int n_threads = 1;
     llama_pos n_past = 0;
@@ -73,7 +74,7 @@ struct gemma3_context {
         lctx = llama_init.context.get();
         vocab = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
-        batch = llama_batch_init(params.n_batch, 0, 1);
+        batch.reset(llama_batch_ext_init(params.n_batch, 1));
         init_clip_model(params);
     }
 
@@ -87,50 +88,18 @@ struct gemma3_context {
     }
 };
 
-struct decode_embd_batch {
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id>   seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
-    llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens =*/ n_tokens,
-            /*tokens   =*/ nullptr,
-            /*embd     =*/ embd,
-            /*pos      =*/ pos.data(),
-            /*n_seq_id =*/ n_seq_id.data(),
-            /*seq_id   =*/ seq_ids.data(),
-            /*logits   =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
-
 static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
     llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-    common_batch_clear(ctx.batch);
+    llama_batch_ext_clear(ctx.batch.get());
     for (llama_token & t : tokens) {
-        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false);
     }
     if (logits_last) {
-        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
+        llama_batch_ext_set_output_last(ctx.batch.get());
     }
     // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
-    if (llama_decode(ctx.lctx, ctx.batch)) {
+    if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
         LOG_ERR("Failed to decode text\n");
         return 1;
     }
@@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
     int64_t t1 = ggml_time_ms();
     eval_text(ctx, "<start_of_image>");
     llama_set_causal_attn(ctx.lctx, false);
-    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
-    if (llama_decode(ctx.lctx, batch_img.batch)) {
+    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
+    if (llama_decode_ext(ctx.lctx, batch_img.get())) {
         LOG_ERR("failed to decode image\n");
         return 1;
     }
@@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
         fflush(stdout);
 
         // eval the token
-        common_batch_clear(ctx.batch);
-        common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
-        if (llama_decode(ctx.lctx, ctx.batch)) {
+        llama_batch_ext_clear(ctx.batch.get());
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
+        if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
             LOG_ERR("failed to decode token\n");
             return 1;
         }
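
For orientation, here is a minimal sketch of the text-eval pattern this change migrates to, using only the calls that appear in the diff above (`llama_batch_ext_clear`, `llama_batch_ext_add_text`, `llama_batch_ext_set_output_last`, `llama_decode_ext`). The helper function, its name, and the exact signatures are assumptions inferred from their use in this diff, not a documented public API.

```cpp
// Illustrative sketch only -- mirrors the pattern in eval_text()/generate_response()
// above; assumes the llama_batch_ext API behaves as it is used in this diff.
static int eval_tokens_ext(llama_context * lctx,
                           llama_batch_ext_ptr & batch,   // RAII wrapper from llama-cpp.h (assumed)
                           const llama_tokens & tokens,
                           llama_pos & n_past,
                           bool logits_last) {
    llama_batch_ext_clear(batch.get());               // reuse one batch across calls
    for (llama_token t : tokens) {
        llama_seq_id seq_id = 0;                      // single-sequence case, as in the diff
        llama_batch_ext_add_text(batch.get(), t, n_past++, &seq_id, 1, false);
    }
    if (logits_last) {
        llama_batch_ext_set_output_last(batch.get()); // only the last token produces logits
    }
    if (llama_decode_ext(lctx, batch.get())) {        // replaces llama_decode(lctx, batch)
        return 1;
    }
    return 0;
}
```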