@@ -160,8 +160,9 @@ struct mtmd_cli_context {
     }
 };
 
-static int generate_response(mtmd_cli_context & ctx, int n_predict) {
+static std::string generate_response(mtmd_cli_context & ctx, int n_predict) {
     llama_tokens generated_tokens;
+    std::string response = "";
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating || g_is_interrupted) {
             LOG("\n");
@@ -176,8 +177,9 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
             LOG("\n");
             break; // end of generation
         }
-
-        LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
+        std::string piece = common_token_to_piece(ctx.lctx, token_id);
+        LOG("%s", piece.c_str());
+        response += piece;
         fflush(stdout);
 
         if (g_is_interrupted) {
@@ -190,17 +192,18 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
         if (llama_decode(ctx.lctx, ctx.batch)) {
             LOG_ERR("failed to decode token\n");
-            return 1;
+            return "";
         }
     }
-    return 0;
+    return response;
 }
 
-static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
+static int eval_message(mtmd_cli_context & ctx, const std::vector<common_chat_msg> & messages, bool add_bos = false) {
     common_chat_templates_inputs tmpl_inputs;
-    tmpl_inputs.messages = {msg};
+    tmpl_inputs.messages = messages;
     tmpl_inputs.add_generation_prompt = true;
-    tmpl_inputs.use_jinja = false; // jinja is buggy here
+    tmpl_inputs.no_part_concat = true;
+    tmpl_inputs.use_jinja = true;
     auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
     LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
 
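Note on the two signature changes above: `generate_response` now returns the accumulated text instead of an exit code, signalling failure with an empty string (which the call sites below test with `.empty()`), and `eval_message` takes the whole chat history rather than a single message, so the chat template is re-applied over every turn. A standalone sketch of the resulting multi-turn flow follows the diff.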
@@ -303,10 +306,10 @@ int main(int argc, char ** argv) {
                 return 1; // error is already printed by libmtmd
             }
         }
-        if (eval_message(ctx, msg, true)) {
+        if (eval_message(ctx, {msg}, true)) {
             return 1;
         }
-        if (!g_is_interrupted && generate_response(ctx, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, n_predict).empty()) {
             return 1;
         }
 
@@ -324,7 +327,7 @@ int main(int argc, char ** argv) {
 
         bool is_first_msg = true;
         std::string content;
-
+        std::vector<common_chat_msg> messages;
         while (!g_is_interrupted) {
             g_is_generating = false;
             LOG("\n> ");
@@ -357,24 +360,31 @@ int main(int argc, char ** argv) {
                 std::string media_path = line.substr(7);
                 if (ctx.load_media(media_path)) {
                     LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
-                    content += mtmd_default_marker();
+                    // content += mtmd_default_marker();
+                    common_chat_msg msg;
+                    msg.content_parts.push_back({"image", ""});
+                    messages.push_back(std::move(msg));
                 }
                 // else, error is already printed by libmtmd
                 continue;
-            } else {
-                content += line;
             }
             common_chat_msg msg;
             msg.role = "user";
-            msg.content = content;
-            int ret = eval_message(ctx, msg, is_first_msg);
+            msg.content = line;
+            messages.push_back(std::move(msg));
+            int ret = eval_message(ctx, messages, is_first_msg);
             if (ret) {
                 return 1;
             }
             if (g_is_interrupted) break;
-            if (generate_response(ctx, n_predict)) {
+            auto response = generate_response(ctx, n_predict);
+            if (response.empty()) {
                 return 1;
             }
+            common_chat_msg response_message;
+            response_message.role = "assistant";
+            response_message.content = response;
+            messages.push_back(response_message);
             content.clear();
             is_first_msg = false;
         }
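For readers skimming the diff, here is a minimal standalone sketch of the multi-turn pattern the interactive loop now follows: keep a growing `messages` vector, re-render the full history into a prompt each turn, and append the model's reply so the next turn sees it. `ChatMsg`, `format_prompt`, and `run_model` are illustrative stand-ins, not llama.cpp APIs; the real code uses `common_chat_msg`, `common_chat_templates_apply`, and `generate_response`.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Stand-in for common_chat_msg (illustrative only).
struct ChatMsg { std::string role, content; };

// Re-render the whole history each turn, as eval_message now does via
// the chat template. The tag syntax here is made up for the sketch.
static std::string format_prompt(const std::vector<ChatMsg> & messages) {
    std::string out;
    for (const auto & m : messages) {
        out += "<" + m.role + ">" + m.content + "\n";
    }
    return out;
}

// Placeholder for generate_response(); returns a dummy string.
static std::string run_model(const std::string & prompt) {
    return "stub reply (" + std::to_string(prompt.size()) + " chars of context)";
}

int main() {
    std::vector<ChatMsg> messages;
    std::string line;
    while (std::getline(std::cin, line) && !line.empty()) {
        messages.push_back({"user", line});          // user turn joins the history
        const std::string response = run_model(format_prompt(messages));
        std::cout << response << "\n";
        messages.push_back({"assistant", response}); // reply joins the history too,
                                                     // mirroring response_message above
    }
    return 0;
}
```

Because the template is re-applied over the entire history each turn, the role stored with each reply matters: keeping the model's output under the correct role is what keeps subsequent prompts well-formed.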