@@ -76,9 +76,11 @@ struct mtmd_cli_context {
 
     mtmd::bitmaps bitmaps;
 
-    // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
-    // so here we don't need to keep track of chat history
+    // chat template
     common_chat_templates_ptr tmpls;
+    std::vector<common_chat_msg> chat_history;
+    bool use_jinja = false;
+    // TODO: support for --system-prompt with /clear command
 
     // support for legacy templates (models not having EOT token)
     llama_tokens antiprompt_tokens;
@@ -108,6 +110,8 @@ struct mtmd_cli_context {
         }
 
         tmpls = common_chat_templates_init(model, params.chat_template);
+        use_jinja = params.use_jinja;
+        chat_history.clear();
         LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());
 
         init_vision_context(params);
@@ -193,19 +197,33 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
             return 1;
         }
     }
+
+    std::string generated_text = common_detokenize(ctx.lctx, generated_tokens);
+    common_chat_msg msg;
+    msg.role    = "assistant";
+    msg.content = generated_text;
+    ctx.chat_history.push_back(std::move(msg));
+
     return 0;
 }
 
-static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
-    common_chat_templates_inputs tmpl_inputs;
-    tmpl_inputs.messages = {msg};
-    tmpl_inputs.add_generation_prompt = true;
-    tmpl_inputs.use_jinja = false; // jinja is buggy here
-    auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
-    LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
+static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & new_msg) {
+    LOG_DBG("chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n",
+        new_msg.role.c_str(), new_msg.content.c_str());
+    auto formatted = common_chat_format_single(ctx.tmpls.get(), ctx.chat_history,
+        new_msg, new_msg.role == "user",
+        ctx.use_jinja);
+    ctx.chat_history.push_back(new_msg);
+    return formatted;
+}
+
+static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
+    bool add_bos = ctx.chat_history.empty();
+    auto formatted_chat = chat_add_and_format(ctx, msg);
+    LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str());
 
     mtmd_input_text text;
-    text.text          = formatted_chat.prompt.c_str();
+    text.text          = formatted_chat.c_str();
     text.add_special   = add_bos;
     text.parse_special = true;
 
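Note on the new `chat_add_and_format()`: instead of applying the template to a single isolated message (the old `common_chat_templates_apply` path), it calls `common_chat_format_single()` against the accumulated `chat_history`, so only the text of the new turn is tokenized and decoded while earlier turns stay in the KV cache. The sketch below is a standalone illustration of that "format only the delta" idea; `msg_t`, `apply_template()` and `format_single()` are hypothetical stand-ins, not llama.cpp code.

```cpp
#include <string>
#include <vector>

struct msg_t { std::string role, content; };   // stand-in for common_chat_msg

// hypothetical template expansion; the real code delegates this to
// common_chat_format_single() / the model's chat template
static std::string apply_template(const std::vector<msg_t> & msgs, bool add_assistant_prompt) {
    std::string out;
    for (const auto & m : msgs) {
        out += "<|" + m.role + "|>\n" + m.content + "\n";
    }
    if (add_assistant_prompt) {
        out += "<|assistant|>\n";
    }
    return out;
}

// render the history with and without the new message, return only the
// newly appended suffix, and record the message for the next turn
static std::string format_single(std::vector<msg_t> & history, const msg_t & new_msg) {
    const std::string before = apply_template(history, false);
    std::vector<msg_t> with_new = history;
    with_new.push_back(new_msg);
    const std::string after = apply_template(with_new, new_msg.role == "user");
    history.push_back(new_msg);
    return after.substr(before.size());   // earlier turns are already in the KV cache
}

int main() {
    std::vector<msg_t> history;
    std::string turn1 = format_single(history, {"user", "hi"});
    std::string turn2 = format_single(history, {"user", "and again"});
    // turn2 contains only the second user turn plus the assistant prompt
    return (!turn1.empty() && !turn2.empty()) ? 0 : 1;
}
```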
@@ -303,7 +321,7 @@ int main(int argc, char ** argv) {
                 return 1; // error is already printed by libmtmd
             }
         }
-        if (eval_message(ctx, msg, true)) {
+        if (eval_message(ctx, msg)) {
             return 1;
         }
         if (!g_is_interrupted && generate_response(ctx, n_predict)) {
@@ -322,7 +340,6 @@ int main(int argc, char ** argv) {
         LOG("\n /quit or /exit        exit the program");
         LOG("\n");
 
-        bool is_first_msg = true;
         std::string content;
 
         while (!g_is_interrupted) {
@@ -342,7 +359,8 @@ int main(int argc, char ** argv) {
             }
             if (line == "/clear") {
                 ctx.n_past = 0;
-                llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
+                ctx.chat_history.clear();
+                llama_memory_clear(llama_get_memory(ctx.lctx), true);
                 LOG("Chat history cleared\n\n");
                 continue;
             }
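The `/clear` branch now resets three pieces of state that must stay in sync: the decode position, the message history, and the KV cache. The old `llama_memory_seq_rm(..., 0, 1, -1)` trick of preserving BOS is no longer needed, because `add_bos` is re-derived from an empty history. A hypothetical helper, shown only to make the grouping explicit (the patch keeps these lines inline in the REPL loop, and this relies on the includes and `mtmd_cli_context` already present in mtmd-cli.cpp):

```cpp
static void clear_chat(mtmd_cli_context & ctx) {
    ctx.n_past = 0;                                       // restart the decode position
    ctx.chat_history.clear();                             // forget the formatted turns
    llama_memory_clear(llama_get_memory(ctx.lctx), true); // drop the whole KV cache
    // add_bos comes from chat_history.empty(), so the next eval_message()
    // re-adds BOS on its own; keeping it in the cache is no longer required
}
```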
@@ -367,7 +385,7 @@ int main(int argc, char ** argv) {
             common_chat_msg msg;
             msg.role    = "user";
             msg.content = content;
-            int ret = eval_message(ctx, msg, is_first_msg);
+            int ret = eval_message(ctx, msg);
             if (ret) {
                 return 1;
             }
@@ -376,7 +394,6 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             content.clear();
-            is_first_msg = false;
         }
     }
     if (g_is_interrupted) LOG("\nInterrupted by user\n");
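Taken together, the interactive loop now carries state across turns: `is_first_msg` is gone because `eval_message()` derives `add_bos` from `chat_history.empty()`, and `generate_response()` appends the assistant reply so the next `chat_add_and_format()` call sees the whole conversation. A minimal standalone simulation of that flow (stand-in `msg_t` type and a faked model reply; none of this is llama.cpp code):

```cpp
#include <cstdio>
#include <string>
#include <vector>

struct msg_t { std::string role, content; };   // stand-in for common_chat_msg

int main() {
    std::vector<msg_t> chat_history;           // mirrors mtmd_cli_context::chat_history

    const char * user_turns[] = { "describe this image", "what color is the cat?" };
    for (const char * content : user_turns) {
        const bool add_bos = chat_history.empty();   // same rule as eval_message()
        chat_history.push_back({"user", content});
        std::printf("turn %zu: add_bos=%d, history=%zu msgs\n",
                    chat_history.size() / 2 + 1, add_bos ? 1 : 0, chat_history.size());

        // generate_response() would decode here; we fake a reply and, like the
        // patch, push it back so the next turn is formatted against the full history
        chat_history.push_back({"assistant", "(model reply)"});
    }
    return 0;
}
```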