Skip to content

Commit 03a4a49

Browse files
committed
Merge branch 'master' into xsn/paddleocr
2 parents 030f1b2 + d0660f2 commit 03a4a49

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

tools/mtmd/mtmd-cli.cpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@ struct mtmd_cli_context {
7676

7777
mtmd::bitmaps bitmaps;
7878

79-
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
80-
// so here we don't need to keep track of chat history
79+
// chat template
8180
common_chat_templates_ptr tmpls;
81+
std::vector<common_chat_msg> chat_history;
82+
bool use_jinja = false;
83+
// TODO: support for --system-prompt with /clear command
8284

8385
// support for legacy templates (models not having EOT token)
8486
llama_tokens antiprompt_tokens;
@@ -110,6 +112,7 @@ struct mtmd_cli_context {
110112

111113
tmpls = common_chat_templates_init(model, params.chat_template);
112114
use_jinja = params.use_jinja;
115+
chat_history.clear();
113116
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());
114117

115118
init_vision_context(params);
@@ -195,19 +198,33 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
195198
return 1;
196199
}
197200
}
201+
202+
std::string generated_text = common_detokenize(ctx.lctx, generated_tokens);
203+
common_chat_msg msg;
204+
msg.role = "assistant";
205+
msg.content = generated_text;
206+
ctx.chat_history.push_back(std::move(msg));
207+
198208
return 0;
199209
}
200210

201-
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
202-
common_chat_templates_inputs tmpl_inputs;
203-
tmpl_inputs.messages = {msg};
204-
tmpl_inputs.add_generation_prompt = true;
205-
tmpl_inputs.use_jinja = ctx.use_jinja;
206-
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
207-
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
211+
static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & new_msg) {
212+
LOG_DBG("chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n",
213+
new_msg.role.c_str(), new_msg.content.c_str());
214+
auto formatted = common_chat_format_single(ctx.tmpls.get(), ctx.chat_history,
215+
new_msg, new_msg.role == "user",
216+
ctx.use_jinja);
217+
ctx.chat_history.push_back(new_msg);
218+
return formatted;
219+
}
220+
221+
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
222+
bool add_bos = ctx.chat_history.empty();
223+
auto formatted_chat = chat_add_and_format(ctx, msg);
224+
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str());
208225

209226
mtmd_input_text text;
210-
text.text = formatted_chat.prompt.c_str();
227+
text.text = formatted_chat.c_str();
211228
text.add_special = add_bos;
212229
text.parse_special = true;
213230

@@ -305,7 +322,7 @@ int main(int argc, char ** argv) {
305322
return 1; // error is already printed by libmtmd
306323
}
307324
}
308-
if (eval_message(ctx, msg, true)) {
325+
if (eval_message(ctx, msg)) {
309326
return 1;
310327
}
311328
if (!g_is_interrupted && generate_response(ctx, n_predict)) {
@@ -324,7 +341,6 @@ int main(int argc, char ** argv) {
324341
LOG("\n /quit or /exit exit the program");
325342
LOG("\n");
326343

327-
bool is_first_msg = true;
328344
std::string content;
329345

330346
while (!g_is_interrupted) {
@@ -344,7 +360,8 @@ int main(int argc, char ** argv) {
344360
}
345361
if (line == "/clear") {
346362
ctx.n_past = 0;
347-
llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
363+
ctx.chat_history.clear();
364+
llama_memory_clear(llama_get_memory(ctx.lctx), true);
348365
LOG("Chat history cleared\n\n");
349366
continue;
350367
}
@@ -369,7 +386,7 @@ int main(int argc, char ** argv) {
369386
common_chat_msg msg;
370387
msg.role = "user";
371388
msg.content = content;
372-
int ret = eval_message(ctx, msg, is_first_msg);
389+
int ret = eval_message(ctx, msg);
373390
if (ret) {
374391
return 1;
375392
}
@@ -378,7 +395,6 @@ int main(int argc, char ** argv) {
378395
return 1;
379396
}
380397
content.clear();
381-
is_first_msg = false;
382398
}
383399
}
384400
if (g_is_interrupted) LOG("\nInterrupted by user\n");

0 commit comments

Comments
 (0)