
Commit d0660f2

mtmd-cli : allow using --jinja (#16718)
* mtmd-cli : allow using --jinja
* support -sys
* implement chat_history
* fix clear memory
* rm -sys support, added TODO
1 parent fe6a988 commit d0660f2
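
With this change the --jinja flag (and the LLAMA_ARG_JINJA environment variable) is accepted by the multimodal CLI as well, so an invocation along the lines of llama-mtmd-cli -m model.gguf --mmproj mmproj.gguf --image photo.jpg -p "describe the image" --jinja (binary name and file names here are illustrative placeholders) formats the conversation with the model's chat template and keeps a running chat history across turns.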

2 files changed (+33, -16)


common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -3435,7 +3435,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.use_jinja = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
     add_opt(common_arg(
         {"--reasoning-format"}, "FORMAT",
         "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"

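The only functional change in common/arg.cpp is adding LLAMA_EXAMPLE_MTMD to set_examples(), which makes --jinja (and its LLAMA_ARG_JINJA environment variable) visible to the mtmd CLI. As a sketch of the same registration pattern, with a made-up flag name used purely for illustration:

// illustration only: a hypothetical boolean option scoped to the mtmd example,
// following the add_opt(common_arg(...)) pattern from the hunk above
add_opt(common_arg(
    {"--my-hypothetical-flag"},
    "enable some hypothetical behaviour (default: disabled)",
    [](common_params & params) {
        params.use_jinja = true; // a real option would set its own common_params field
    }
).set_examples({LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_MY_HYPOTHETICAL_FLAG"));
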
tools/mtmd/mtmd-cli.cpp

Lines changed: 32 additions & 15 deletions
@@ -76,9 +76,11 @@ struct mtmd_cli_context {
 
     mtmd::bitmaps bitmaps;
 
-    // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
-    // so here we don't need to keep track of chat history
+    // chat template
     common_chat_templates_ptr tmpls;
+    std::vector<common_chat_msg> chat_history;
+    bool use_jinja = false;
+    // TODO: support for --system-prompt with /clear command
 
     // support for legacy templates (models not having EOT token)
     llama_tokens antiprompt_tokens;
@@ -108,6 +110,8 @@ struct mtmd_cli_context {
         }
 
         tmpls = common_chat_templates_init(model, params.chat_template);
+        use_jinja = params.use_jinja;
+        chat_history.clear();
         LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());
 
         init_vision_context(params);
@@ -193,19 +197,33 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
             return 1;
         }
     }
+
+    std::string generated_text = common_detokenize(ctx.lctx, generated_tokens);
+    common_chat_msg msg;
+    msg.role = "assistant";
+    msg.content = generated_text;
+    ctx.chat_history.push_back(std::move(msg));
+
     return 0;
 }
 
-static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
-    common_chat_templates_inputs tmpl_inputs;
-    tmpl_inputs.messages = {msg};
-    tmpl_inputs.add_generation_prompt = true;
-    tmpl_inputs.use_jinja = false; // jinja is buggy here
-    auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
-    LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
+static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & new_msg) {
+    LOG_DBG("chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n",
+            new_msg.role.c_str(), new_msg.content.c_str());
+    auto formatted = common_chat_format_single(ctx.tmpls.get(), ctx.chat_history,
+                                               new_msg, new_msg.role == "user",
+                                               ctx.use_jinja);
+    ctx.chat_history.push_back(new_msg);
+    return formatted;
+}
+
+static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
+    bool add_bos = ctx.chat_history.empty();
+    auto formatted_chat = chat_add_and_format(ctx, msg);
+    LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str());
 
     mtmd_input_text text;
-    text.text = formatted_chat.prompt.c_str();
+    text.text = formatted_chat.c_str();
     text.add_special = add_bos;
     text.parse_special = true;
 
@@ -303,7 +321,7 @@ int main(int argc, char ** argv) {
                 return 1; // error is already printed by libmtmd
             }
         }
-        if (eval_message(ctx, msg, true)) {
+        if (eval_message(ctx, msg)) {
            return 1;
        }
        if (!g_is_interrupted && generate_response(ctx, n_predict)) {
@@ -322,7 +340,6 @@ int main(int argc, char ** argv) {
         LOG("\n /quit or /exit exit the program");
         LOG("\n");
 
-        bool is_first_msg = true;
         std::string content;
 
         while (!g_is_interrupted) {
@@ -342,7 +359,8 @@ int main(int argc, char ** argv) {
             }
             if (line == "/clear") {
                 ctx.n_past = 0;
-                llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
+                ctx.chat_history.clear();
+                llama_memory_clear(llama_get_memory(ctx.lctx), true);
                 LOG("Chat history cleared\n\n");
                 continue;
             }
@@ -367,7 +385,7 @@ int main(int argc, char ** argv) {
             common_chat_msg msg;
             msg.role = "user";
             msg.content = content;
-            int ret = eval_message(ctx, msg, is_first_msg);
+            int ret = eval_message(ctx, msg);
             if (ret) {
                 return 1;
             }
@@ -376,7 +394,6 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             content.clear();
-            is_first_msg = false;
         }
     }
     if (g_is_interrupted) LOG("\nInterrupted by user\n");
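
Conceptually, chat_add_and_format() relies on common_chat_format_single() to return only the text appended by the newest message, while the whole conversation is kept in ctx.chat_history. A minimal stand-alone sketch of that idea (plain C++, not the llama.cpp API; render_all is a made-up stand-in for applying the full chat template):

#include <string>
#include <vector>

struct chat_msg { std::string role, content; };

// stand-in for rendering the whole conversation with the chat template
static std::string render_all(const std::vector<chat_msg> & history) {
    std::string out;
    for (const auto & m : history) {
        out += m.role + ": " + m.content + "\n";
    }
    return out;
}

// render the history with and without the new message and return only the
// appended suffix; that suffix is what gets tokenized and evaluated, while the
// full history is kept so the next turn can be formatted against it
static std::string format_single(std::vector<chat_msg> & history, const chat_msg & new_msg) {
    const std::string before = render_all(history);
    history.push_back(new_msg);
    const std::string after = render_all(history);
    return after.substr(before.size());
}

This is also why /clear now resets both chat_history and the llama memory: with a Jinja template the rendering of a new turn can depend on every earlier turn, so the old "each turn is independent" assumption from the removed comment no longer holds.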
