 #include "llama.h"
 #include "ggml.h"
 #include "console.h"
+#include "chat.h"
 #include "llava2.h"
 
 #include <vector>
@@ -56,13 +57,18 @@ static void sigint_handler(int signo) {
 #endif
 
 struct gemma3_context {
-    llava2_context_ptr ctx_llava2;
+    llava2_context_ptr ctx_vision;
     common_init_result llama_init;
 
     llama_model       * model;
     llama_context     * lctx;
     const llama_vocab * vocab;
     llama_batch         batch;
+    int                 n_batch;
+
+    // note: the gemma3 chat template is "linear", meaning each turn is completely separate from the others,
+    // so we don't need to keep track of the chat history here
+    common_chat_templates_ptr tmpls;
 
     int n_threads    = 1;
     llama_pos n_past = 0;
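
The "linear" property noted above can be read off the hard-coded prompt strings that this file used to build by hand (removed further down in this diff): every Gemma 3 turn renders as a self-contained block, so a new turn can simply be appended to the KV cache without re-rendering earlier messages. Roughly (whitespace approximate):

    <start_of_turn>user
    ...user text, with image tokens expanded in place...<end_of_turn><start_of_turn>model
    ...model reply...<end_of_turn>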
@@ -73,18 +79,20 @@ struct gemma3_context {
         vocab     = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
         batch     = llama_batch_init(params.n_batch, 0, 1);
-        init_clip_model(params);
+        n_batch   = params.n_batch;
+        tmpls     = common_chat_templates_init(model, params.chat_template);
+        init_vision_context(params);
     }
 
-    void init_clip_model(common_params & params) {
+    void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
-        ctx_llava2 = llava2_init_from_file(clip_path, model, llava2_context_params{
+        ctx_vision = llava2_init_from_file(clip_path, model, llava2_context_params{
             /* use_gpu */   true,
             /* n_threads */ params.cpuparams.n_threads,
             /* verbosity */ GGML_LOG_LEVEL_INFO,
         });
-        if (!ctx_llava2.get()) {
-            LOG_ERR("Failed to load CLIP model from %s\n", clip_path);
+        if (!ctx_vision.get()) {
+            LOG_ERR("Failed to load vision model from %s\n", clip_path);
             exit(1);
         }
     }
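
The three llava2_context_params fields above are the entire configuration surface this example uses, so a standalone vision-context setup reduces to the sketch below; the file path is a placeholder, and only calls that appear in this diff are used:

    llava2_context_ptr ctx_vision = llava2_init_from_file("mmproj-model-f16.gguf", model, llava2_context_params{
        /* use_gpu */   true,
        /* n_threads */ 4,
        /* verbosity */ GGML_LOG_LEVEL_INFO,
    });
    if (!ctx_vision.get()) {
        // mmproj file missing or incompatible with the text model
        exit(1);
    }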
@@ -123,77 +131,6 @@ struct decode_embd_batch {
     }
 };
 
-static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
-    llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-    common_batch_clear(ctx.batch);
-    for (llama_token & t : tokens) {
-        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
-    }
-    if (logits_last) {
-        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
-    }
-    //LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
-    if (llama_decode(ctx.lctx, ctx.batch)) {
-        LOG_ERR("Failed to decode text\n");
-        return 1;
-    }
-    return 0;
-}
-
-static int eval_image(gemma3_context & ctx, std::string & fname) {
-    std::vector<float> image_embd_v;
-    int n_embd = llama_model_n_embd(ctx.model);
-    int n_tokens = 256;
-    image_embd_v.resize(n_tokens * n_embd);
-
-    bool ok;
-    struct clip_image_u8 * img_u8 = clip_image_u8_init();
-    ok = clip_image_load_from_file(fname.c_str(), img_u8);
-    if (!ok) {
-        LOG_ERR("Unable to load image %s\n", fname.c_str());
-        clip_image_u8_free(img_u8);
-        return 2; // non-fatal error
-    }
-
-    clip_image_f32_batch batch_f32;
-    ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
-    if (!ok) {
-        LOG_ERR("Unable to preprocess image\n");
-        clip_image_f32_batch_free(&batch_f32);
-        clip_image_u8_free(img_u8);
-        return 1;
-    }
-
-    int64_t t0 = ggml_time_ms();
-    LOG("Encoding image %s\n", fname.c_str());
-    ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
-    if (!ok) {
-        LOG_ERR("Unable to encode image\n");
-        clip_image_f32_batch_free(&batch_f32);
-        clip_image_u8_free(img_u8);
-        return 1;
-    }
-    LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
-
-    clip_image_f32_batch_free(&batch_f32);
-    clip_image_u8_free(img_u8);
-
-    // decode image embeddings
-    int64_t t1 = ggml_time_ms();
-    eval_text(ctx, "<start_of_image>");
-    llama_set_causal_attn(ctx.lctx, false);
-    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
-    if (llama_decode(ctx.lctx, batch_img.batch)) {
-        LOG_ERR("failed to decode image\n");
-        return 1;
-    }
-    ctx.n_past += n_tokens;
-    llama_set_causal_attn(ctx.lctx, true);
-    eval_text(ctx, "<end_of_image>");
-    LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
-    return 0;
-}
-
 static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating) {
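
The deleted eval_image above is worth reading for the decode pattern it spells out, which the llava2 helpers presumably now encapsulate: image embeddings are injected with causal attention turned off, so the image tokens attend to each other bidirectionally, and causal masking is restored before the following text. Reduced to its core:

    llama_set_causal_attn(ctx.lctx, false);    // bidirectional attention across the image span
    llama_decode(ctx.lctx, batch_img.batch);   // feed the precomputed image embeddings
    ctx.n_past += n_tokens;
    llama_set_causal_attn(ctx.lctx, true);     // back to causal attention for text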
@@ -223,6 +160,41 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
     return 0;
 }
 
+static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
+    std::vector<llava2_bitmap> bitmaps;
+
+    common_chat_templates_inputs tmpl_inputs;
+    tmpl_inputs.messages = {msg};
+    tmpl_inputs.add_generation_prompt = true;
+    tmpl_inputs.use_jinja = false; // jinja is buggy here
+    auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
+    LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
+
+    for (auto & fname : images_fname) {
+        llava2_bitmap bitmap;
+        if (llava2_bitmap_init_from_file(fname.c_str(), bitmap)) {
+            LOG_ERR("Unable to load image %s\n", fname.c_str());
+            return 2; // image not found
+        }
+        bitmaps.push_back(std::move(bitmap));
+    }
+
+    std::vector<llava2_input_chunk> chunks;
+    if (llava2_tokenize(ctx.ctx_vision, chunks, formatted_chat.prompt, add_bos, true, bitmaps)) {
+        LOG_ERR("Unable to tokenize prompt\n");
+        return 1;
+    }
+
+    if (llava2_helper_eval(ctx.ctx_vision, ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
+        LOG_ERR("Unable to eval prompt\n");
+        return 1;
+    }
+
+    ctx.n_past += llava2_helper_get_n_tokens(chunks);
+
+    return 0;
+}
+
 int main(int argc, char ** argv) {
     ggml_time_init();
 
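eval_message plus the pre-existing generate_response make up the whole per-turn flow. A minimal single-turn sketch using only functions defined in this file (the image path and n_predict value are placeholders):

    common_chat_msg msg;
    msg.role    = "user";
    msg.content = "What is in this photo? <__image__>";
    std::vector<std::string> images = { "photo.jpg" }; // placeholder path
    if (eval_message(ctx, msg, images, /* add_bos */ true) == 0) {
        generate_response(ctx, smpl, /* n_predict */ 256);
    }
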
@@ -264,22 +236,15 @@ int main(int argc, char ** argv) {
 #endif
     }
 
-    if (eval_text(ctx, "<bos>")) {
-        return 1;
-    }
-
     if (is_single_turn) {
         g_is_generating = true;
-        std::string prompt = "<start_of_turn>user\n<image>" + params.prompt + "<end_of_turn><start_of_turn>model\n";
-        if (eval_text(ctx, "<start_of_turn>user\n")) {
-            return 1;
-        }
-        for (auto & fname : params.image) {
-            if (eval_image(ctx, fname)) {
-                return 1;
-            }
+        if (params.prompt.find("<__image__>") == std::string::npos) {
+            params.prompt += "<__image__>";
         }
-        if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
+        common_chat_msg msg;
+        msg.role    = "user";
+        msg.content = params.prompt;
+        if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
         if (generate_response(ctx, smpl, n_predict)) {
@@ -293,9 +258,9 @@ int main(int argc, char ** argv) {
         LOG("\n/quit or /exit exit the program");
         LOG("\n");
 
-        if (eval_text(ctx, "<start_of_turn>user\n")) {
-            return 1;
-        }
+        bool is_first_msg = true;
+        std::vector<std::string> images_fname;
+        std::string content;
 
         while (true) {
             g_is_generating = false;
@@ -320,24 +285,31 @@ int main(int argc, char ** argv) {
             g_is_generating = true;
             if (line.find("/image") == 0) {
                 std::string image = line.substr(7);
-                int res = eval_image(ctx, image);
-                if (res == 2) {
-                    continue; // image not found
-                }
-                if (res) {
-                    return 1;
-                }
+                images_fname.push_back(string_strip(image));
+                content += "<__image__>";
                 continue;
+            } else {
+                content += line;
             }
-            if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
-                return 1;
+            common_chat_msg msg;
+            msg.role    = "user";
+            msg.content = content;
+            int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (ret == 2) {
+                // non-fatal error
+                images_fname.clear();
+                content.clear();
+                continue;
             }
-            if (generate_response(ctx, smpl, n_predict)) {
+            if (ret) {
                 return 1;
             }
-            if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
+            if (generate_response(ctx, smpl, n_predict)) {
                 return 1;
             }
+            images_fname.clear();
+            content.clear();
+            is_first_msg = false;
         }
     }
 
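In chat mode the loop accumulates content and images_fname across successive /image commands, evaluates the turn once a regular line arrives, then clears both for the next turn; is_first_msg only controls whether a BOS token is prepended. A session might look like this (model output hypothetical):

    > /image photo.jpg
    > what is in this picture?
    A cat sitting on a windowsill.
    > /quit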