Skip to content

Commit dfb84f6

Browse files
committed
implement chat_history
1 parent 47d893c commit dfb84f6

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

tools/mtmd/mtmd-cli.cpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,16 @@ struct mtmd_cli_context {
7676

7777
mtmd::bitmaps bitmaps;
7878

79-
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
80-
// so here we don't need to keep track of chat history
79+
// chat template
8180
common_chat_templates_ptr tmpls;
81+
std::vector<common_chat_msg> chat_history;
82+
bool use_jinja = false;
8283

8384
// support for legacy templates (models not having EOT token)
8485
llama_tokens antiprompt_tokens;
8586

8687
int n_threads = 1;
8788
llama_pos n_past = 0;
88-
bool use_jinja = false;
8989

9090
mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
9191
model = llama_init.model.get();
@@ -110,6 +110,12 @@ struct mtmd_cli_context {
110110

111111
tmpls = common_chat_templates_init(model, params.chat_template);
112112
use_jinja = params.use_jinja;
113+
if (!params.system_prompt.empty()) {
114+
common_chat_msg sys_msg;
115+
sys_msg.role = "system";
116+
sys_msg.content = params.system_prompt;
117+
chat_history.push_back(std::move(sys_msg));
118+
}
113119
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());
114120

115121
init_vision_context(params);
@@ -195,19 +201,32 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
195201
return 1;
196202
}
197203
}
204+
205+
std::string generated_text = common_detokenize(ctx.lctx, generated_tokens);
206+
common_chat_msg msg;
207+
msg.role = "assistant";
208+
msg.content = generated_text;
209+
ctx.chat_history.push_back(std::move(msg));
210+
198211
return 0;
199212
}
200213

214+
// Formats `new_msg` against the conversation recorded so far, then appends it
// to the history kept in `ctx`.
//
// Steps, in order:
//   1. debug-log the role and content of `new_msg`;
//   2. call common_chat_format_single() with the context's templates, the
//      current ctx.chat_history, `new_msg`, the flag (new_msg.role == "user")
//      and ctx.use_jinja;
//   3. push a copy of `new_msg` onto ctx.chat_history.
//
// @param ctx      CLI context; its `tmpls`, `chat_history` and `use_jinja`
//                 members are read, and `chat_history` grows by one entry.
// @param new_msg  message to format; it is copied (not moved) into the history,
//                 so the caller's object is left untouched.
// @return         the result of common_chat_format_single().
//
// NOTE(review): presumably common_chat_format_single() returns only the text
// contributed by `new_msg` (not the whole re-rendered conversation), and the
// boolean presumably asks for the assistant generation prompt after a user
// turn - confirm against its declaration in common/chat.h.
static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & new_msg) {
215+
LOG_DBG("chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n",
216+
new_msg.role.c_str(), new_msg.content.c_str());
217+
auto formatted = common_chat_format_single(ctx.tmpls.get(), ctx.chat_history,
218+
new_msg, new_msg.role == "user",
219+
ctx.use_jinja);
220+
// Appended only after formatting, so the history passed to the call above
// does not yet contain new_msg.
ctx.chat_history.push_back(new_msg);
221+
return formatted;
222+
}
223+
201224
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
202-
common_chat_templates_inputs tmpl_inputs;
203-
tmpl_inputs.messages = {msg};
204-
tmpl_inputs.add_generation_prompt = true;
205-
tmpl_inputs.use_jinja = ctx.use_jinja;
206-
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
207-
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
225+
auto formatted_chat = chat_add_and_format(ctx, msg);
226+
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str());
208227

209228
mtmd_input_text text;
210-
text.text = formatted_chat.prompt.c_str();
229+
text.text = formatted_chat.c_str();
211230
text.add_special = add_bos;
212231
text.parse_special = true;
213232

0 commit comments

Comments
 (0)