Commit 0d26141

mtmd-cli : load image right away
1 parent 492f619 commit 0d26141

1 file changed: examples/llava/mtmd-cli.cpp

Lines changed: 27 additions & 25 deletions
@@ -72,6 +72,8 @@ struct mtmd_cli_context {
     llama_batch batch;
     int n_batch;
 
+    std::vector<mtmd_bitmap> bitmaps;
+
     // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
     // so here we don't need to keep track of chat history
     common_chat_templates_ptr tmpls;
@@ -134,6 +136,15 @@ struct mtmd_cli_context {
             antiprompt_tokens.begin()
         );
     }
+
+    bool load_image(const std::string & fname) {
+        mtmd_bitmap bitmap;
+        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
+            return false;
+        }
+        bitmaps.push_back(std::move(bitmap));
+        return true;
+    }
 };
 
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
@@ -172,25 +183,14 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
     return 0;
 }
 
-static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
-    std::vector<mtmd_bitmap> bitmaps;
-
+static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
     common_chat_templates_inputs tmpl_inputs;
     tmpl_inputs.messages = {msg};
     tmpl_inputs.add_generation_prompt = true;
     tmpl_inputs.use_jinja = false; // jinja is buggy here
     auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
     LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
 
-    for (auto & fname : images_fname) {
-        mtmd_bitmap bitmap;
-        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
-            LOG_ERR("Unable to load image %s\n", fname.c_str());
-            return 2; // image not found
-        }
-        bitmaps.push_back(std::move(bitmap));
-    }
-
     mtmd_input_text text;
     text.text = formatted_chat.prompt;
     text.add_special = add_bos;
@@ -199,12 +199,14 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
 
     if (g_is_interrupted) return 0;
 
-    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
+    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
         return 1;
     }
 
+    ctx.bitmaps.clear();
+
     if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
         LOG_ERR("Unable to eval prompt\n");
         return 1;
@@ -267,7 +269,12 @@ int main(int argc, char ** argv) {
         common_chat_msg msg;
         msg.role = "user";
         msg.content = params.prompt;
-        if (eval_message(ctx, msg, params.image, true)) {
+        for (const auto & image : params.image) {
+            if (!ctx.load_image(image)) {
+                return 1; // error is already printed by libmtmd
+            }
+        }
+        if (eval_message(ctx, msg, true)) {
             return 1;
         }
         if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
@@ -282,7 +289,6 @@ int main(int argc, char ** argv) {
         LOG("\n");
 
         bool is_first_msg = true;
-        std::vector<std::string> images_fname;
         std::string content;
 
         while (!g_is_interrupted) {
@@ -313,7 +319,10 @@ int main(int argc, char ** argv) {
                     continue;
                 }
                 std::string image = line.substr(7);
-                images_fname.push_back(string_strip(image));
+                if (ctx.load_image(image)) {
+                    LOG("Image %s loaded\n", image.c_str());
+                }
+                // else, error is already printed by libmtmd
                 content += "<__image__>";
                 continue;
             } else {
@@ -322,21 +331,14 @@ int main(int argc, char ** argv) {
             common_chat_msg msg;
            msg.role = "user";
            msg.content = content;
-            int ret = eval_message(ctx, msg, images_fname, is_first_msg);
-            if (g_is_interrupted) break;
-            if (ret == 2) {
-                // non-fatal error
-                images_fname.clear();
-                content.clear();
-                continue;
-            }
+            int ret = eval_message(ctx, msg, is_first_msg);
             if (ret) {
                 return 1;
             }
+            if (g_is_interrupted) break;
             if (generate_response(ctx, smpl, n_predict)) {
                 return 1;
             }
-            images_fname.clear();
             content.clear();
             is_first_msg = false;
         }
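
In short, images are now decoded and validated the moment they are supplied (via the new mtmd_cli_context::load_image()) rather than inside eval_message(): loaded bitmaps accumulate in ctx.bitmaps, are handed to mtmd_tokenize() on the next eval_message() call, and are cleared right after tokenization. A minimal sketch of the resulting single-turn flow, assuming ctx, smpl, params, and n_predict are initialized as in the rest of mtmd-cli.cpp:

// load every requested image up front; on failure libmtmd has already printed the error
for (const auto & image : params.image) {
    if (!ctx.load_image(image)) {
        return 1;
    }
}

common_chat_msg msg;
msg.role = "user";
msg.content = params.prompt; // image placement is driven by <__image__> markers, as in the chat path

// eval_message() now consumes ctx.bitmaps (and clears it) instead of taking a filename list
if (eval_message(ctx, msg, /* add_bos */ true)) {
    return 1;
}
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
    return 1;
}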
