@@ -76,9 +76,11 @@ struct mtmd_cli_context {
7676
7777 mtmd::bitmaps bitmaps;
7878
79- // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
80- // so here we don't need to keep track of chat history
79+ // chat template
8180 common_chat_templates_ptr tmpls;
81+ std::vector<common_chat_msg> chat_history;
82+ bool use_jinja = false ;
83+ // TODO: support --system-prompt in combination with the /clear command
8284
8385 // support for legacy templates (models not having EOT token)
8486 llama_tokens antiprompt_tokens;
@@ -110,6 +112,7 @@ struct mtmd_cli_context {
110112
111113 tmpls = common_chat_templates_init (model, params.chat_template );
112114 use_jinja = params.use_jinja ;
115+ chat_history.clear ();
113116 LOG_INF (" %s: chat template example:\n %s\n " , __func__, common_chat_format_example (tmpls.get (), params.use_jinja , params.default_template_kwargs ).c_str ());
114117
115118 init_vision_context (params);
@@ -195,19 +198,33 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
195198 return 1 ;
196199 }
197200 }
201+
202+ std::string generated_text = common_detokenize (ctx.lctx , generated_tokens);
203+ common_chat_msg msg;
204+ msg.role = " assistant" ;
205+ msg.content = generated_text;
206+ ctx.chat_history .push_back (std::move (msg));
207+
198208 return 0 ;
199209}
200210
201- static int eval_message (mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false ) {
202- common_chat_templates_inputs tmpl_inputs;
203- tmpl_inputs.messages = {msg};
204- tmpl_inputs.add_generation_prompt = true ;
205- tmpl_inputs.use_jinja = ctx.use_jinja ;
206- auto formatted_chat = common_chat_templates_apply (ctx.tmpls .get (), tmpl_inputs);
207- LOG_DBG (" formatted_chat.prompt: %s\n " , formatted_chat.prompt .c_str ());
211+ static std::string chat_add_and_format (mtmd_cli_context & ctx, common_chat_msg & new_msg) {
212+ LOG_DBG (" chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n " ,
213+ new_msg.role .c_str (), new_msg.content .c_str ());
214+ auto formatted = common_chat_format_single (ctx.tmpls .get (), ctx.chat_history ,
215+ new_msg, new_msg.role == " user" ,
216+ ctx.use_jinja );
217+ ctx.chat_history .push_back (new_msg);
218+ return formatted;
219+ }
220+
221+ static int eval_message (mtmd_cli_context & ctx, common_chat_msg & msg) {
222+ bool add_bos = ctx.chat_history .empty ();
223+ auto formatted_chat = chat_add_and_format (ctx, msg);
224+ LOG_DBG (" formatted_chat.prompt: %s\n " , formatted_chat.c_str ());
208225
209226 mtmd_input_text text;
210- text.text = formatted_chat.prompt . c_str ();
227+ text.text = formatted_chat.c_str ();
211228 text.add_special = add_bos;
212229 text.parse_special = true ;
213230
@@ -305,7 +322,7 @@ int main(int argc, char ** argv) {
305322 return 1 ; // error is already printed by libmtmd
306323 }
307324 }
308- if (eval_message (ctx, msg, true )) {
325+ if (eval_message (ctx, msg)) {
309326 return 1 ;
310327 }
311328 if (!g_is_interrupted && generate_response (ctx, n_predict)) {
@@ -324,7 +341,6 @@ int main(int argc, char ** argv) {
324341 LOG (" \n /quit or /exit exit the program" );
325342 LOG (" \n " );
326343
327- bool is_first_msg = true ;
328344 std::string content;
329345
330346 while (!g_is_interrupted) {
@@ -344,7 +360,8 @@ int main(int argc, char ** argv) {
344360 }
345361 if (line == " /clear" ) {
346362 ctx.n_past = 0 ;
347- llama_memory_seq_rm (llama_get_memory (ctx.lctx ), 0 , 1 , -1 ); // keep BOS
363+ ctx.chat_history .clear ();
364+ llama_memory_clear (llama_get_memory (ctx.lctx ), true );
348365 LOG (" Chat history cleared\n\n " );
349366 continue ;
350367 }
@@ -369,7 +386,7 @@ int main(int argc, char ** argv) {
369386 common_chat_msg msg;
370387 msg.role = " user" ;
371388 msg.content = content;
372- int ret = eval_message (ctx, msg, is_first_msg );
389+ int ret = eval_message (ctx, msg);
373390 if (ret) {
374391 return 1 ;
375392 }
@@ -378,7 +395,6 @@ int main(int argc, char ** argv) {
378395 return 1 ;
379396 }
380397 content.clear ();
381- is_first_msg = false ;
382398 }
383399 }
384400 if (g_is_interrupted) LOG (" \n Interrupted by user\n " );
0 commit comments