
Commit 4737bd0 (parent: b9154ec)

mtmd : merge llava-cli and gemma3-cli into single mtmd-cli

File tree: 7 files changed (+175, −62 lines)


common/arg.cpp

Lines changed: 1 addition & 1 deletion

@@ -2726,7 +2726,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.chat_template = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LLAVA}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
     add_opt(common_arg(
         {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
         string_format(

examples/llava/CMakeLists.txt

Lines changed: 5 additions & 9 deletions

@@ -61,12 +61,8 @@ if(TARGET BUILD_INFO)
     add_dependencies(mtmd BUILD_INFO)
 endif()
 
-set(TARGET llama-llava-cli)
-add_executable(${TARGET} llava-cli.cpp)
-set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-cli)
-install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
+add_executable(llama-llava-cli deprecation-warning.cpp)
+add_executable(llama-gemma3-cli deprecation-warning.cpp)
 
 set(TARGET llama-minicpmv-cli)
 add_executable(${TARGET} minicpmv-cli.cpp)
@@ -82,9 +78,9 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 
-set(TARGET llama-gemma3-cli)
-add_executable(${TARGET} gemma3-cli.cpp)
-set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
+set(TARGET llama-mtmd-cli)
+add_executable(${TARGET} mtmd-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
examples/llava/deprecation-warning.cpp (new file)

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+#include <cstdio>
+#include <string>
+
+int main(int argc, char** argv) {
+    std::string filename = "main";
+    if (argc >= 1) {
+        filename = argv[0];
+    }
+
+    // Get only the program name from the full path
+    size_t pos = filename.find_last_of("/\\");
+    if (pos != std::string::npos) {
+        filename = filename.substr(pos+1);
+    }
+
+    fprintf(stdout, "\n");
+    fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
+    fprintf(stdout, "Please use 'llama-mtmd-cli' instead.\n");
+    fprintf(stdout, "\n");
+
+    return EXIT_FAILURE;
+}

examples/llava/gemma3-cli.cpp renamed to examples/llava/mtmd-cli.cpp

Lines changed: 35 additions & 8 deletions

@@ -28,15 +28,16 @@ static bool g_is_generating = false;
 
 /**
  * Please note that this is NOT a production-ready stuff.
- * It is a playground for trying Gemma 3 vision capabilities.
+ * It is a playground for trying multimodal support in llama.cpp.
  * For contributors: please keep this code simple and easy to understand.
  */
 
 static void show_additional_info(int /*argc*/, char ** argv) {
     LOG(
-        "Experimental CLI for using Gemma 3 vision model\n\n"
+        "Experimental CLI for multimodal\n\n"
         "Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
         "  -m and --mmproj are required\n"
+        "  -hf user/repo can replace both -m and --mmproj in most cases\n"
        "  --image and -p are optional, if NOT provided, the CLI will run in chat mode\n",
         argv[0]
     );
@@ -56,7 +57,7 @@ static void sigint_handler(int signo) {
 }
 #endif
 
-struct gemma3_context {
+struct mtmd_cli_context {
     mtmd_context_ptr ctx_vision;
     common_init_result llama_init;
 
@@ -70,18 +71,31 @@ struct gemma3_context {
     // so here we don't need to keep track of chat history
     common_chat_templates_ptr tmpls;
 
+    // support for legacy templates (models not having EOT token)
+    llama_tokens antiprompt_tokens;
+
     int n_threads = 1;
     llama_pos n_past = 0;
 
-    gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) {
+    mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
         model = llama_init.model.get();
         lctx = llama_init.context.get();
         vocab = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
         batch = llama_batch_init(params.n_batch, 0, 1);
         n_batch = params.n_batch;
+
         tmpls = common_chat_templates_init(model, params.chat_template);
+        LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja).c_str());
+
         init_vision_context(params);
+
+        // load antiprompt tokens for legacy templates
+        if (params.chat_template == "vicuna") {
+            antiprompt_tokens = common_tokenize(lctx, "ASSISTANT:", false, true);
+        } else if (params.chat_template == "deepseek") {
+            antiprompt_tokens = common_tokenize(lctx, "###", false, true);
+        }
     }
 
     void init_vision_context(common_params & params) {
@@ -97,6 +111,17 @@ struct gemma3_context {
             exit(1);
         }
     }
+
+    bool check_antiprompt(const llama_tokens & generated_tokens) {
+        if (antiprompt_tokens.empty() || generated_tokens.size() < antiprompt_tokens.size()) {
+            return false;
+        }
+        return std::equal(
+            generated_tokens.end() - antiprompt_tokens.size(),
+            generated_tokens.end(),
+            antiprompt_tokens.begin()
+        );
+    }
 };
 
 struct decode_embd_batch {
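The check_antiprompt() helper added above is a plain suffix match on the generated token stream. Below is a minimal, self-contained sketch of that logic, using std::vector<int> as a stand-in for llama_tokens and made-up token values; it illustrates the technique and is not code from this commit.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Stand-in for llama_tokens; the real code compares llama_token values
// produced by common_tokenize().
using tokens = std::vector<int>;

// Returns true when `generated` ends with `antiprompt` (same idea as
// mtmd_cli_context::check_antiprompt in the diff above).
static bool ends_with(const tokens & generated, const tokens & antiprompt) {
    if (antiprompt.empty() || generated.size() < antiprompt.size()) {
        return false;
    }
    return std::equal(generated.end() - antiprompt.size(),
                      generated.end(),
                      antiprompt.begin());
}

int main() {
    const tokens antiprompt = {42, 7};          // hypothetical tokens for "ASSISTANT:"
    assert(!ends_with({1, 2, 3},  antiprompt)); // no match -> keep generating
    assert( ends_with({1, 42, 7}, antiprompt)); // suffix matches -> stop generation
    return 0;
}
```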
@@ -132,17 +157,19 @@ struct decode_embd_batch {
     }
 };
 
-static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
+static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
+    llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating) {
             printf("\n");
             break;
         }
 
         llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
+        generated_tokens.push_back(token_id);
         common_sampler_accept(smpl, token_id, true);
 
-        if (llama_vocab_is_eog(ctx.vocab, token_id)) {
+        if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
             printf("\n");
             break; // end of generation
         }
@@ -161,7 +188,7 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
     return 0;
 }
 
-static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
+static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
     std::vector<mtmd_bitmap> bitmaps;
 
     common_chat_templates_inputs tmpl_inputs;
@@ -218,7 +245,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    gemma3_context ctx(params);
+    mtmd_cli_context ctx(params);
     printf("%s: %s\n", __func__, params.model.path.c_str());
 
     bool is_single_turn = !params.prompt.empty() && !params.image.empty();

examples/llava/mtmd.cpp

Lines changed: 90 additions & 35 deletions

@@ -102,15 +102,20 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
     std::string prompt_modified(text.text);
     std::string marker_modified(ctx->image_marker);
-    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
     // a bit hacky here, but works for now
     // for some models, we need to add prefix and suffix to the image embeddings
-    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
+    if (clip_is_gemma3(ctx->ctx_clip)) {
+        // gemma 3
         // <start_of_image> ... (image embeddings) ... <end_of_image>
         marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
+    // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
+    // for glm-edge, we don't need to add because the tokens are already in the returned embeddings
+
+    // TODO @ngxson : glm-edge : remove BOI / EOI tokens embeddings, decode them as normal tokens
+
     std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
     output.clear();
     output.reserve(parts.size());
@@ -155,11 +160,20 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         }
 
         mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
-        image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
+        image_tokens->nx = clip_n_patches(ctx->ctx_clip) * batch_f32.entries.size(); // TODO @ngxson : use clip_n_patches_by_image
         image_tokens->ny = 1; // TODO
         image_tokens->batch_f32 = std::move(batch_f32);
         image_tokens->id = bitmaps[i_img].id; // optional
 
+        LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
+        LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
+        LOG_DBG("batch_f32 size = %d\n", (int)image_tokens->batch_f32.entries.size());
+
+        if (clip_is_glm(ctx->ctx_clip)) {
+            // glm-edge
+            image_tokens->nx += 2; // add 2 for the begin_of_image and end_of_image token embeddings
+        }
+
         mtmd_input_chunk chunk{
             MTMD_INPUT_CHUNK_TYPE_IMAGE,
             {},
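A quick worked example of the new token count, with assumed numbers (not taken from the commit): if clip_n_patches() reports 256 patches per preprocessed slice and batch_f32 holds 5 slices, nx becomes 1280; a glm-edge projector then reserves 2 extra positions for the begin/end-of-image embeddings.

```cpp
#include <cstdio>

int main() {
    // Assumed values, for illustration only.
    int  n_patches_per_slice = 256;   // what clip_n_patches() might return
    int  n_slices            = 5;     // batch_f32.entries.size()
    bool is_glm_edge         = true;  // what clip_is_glm() might return

    int nx = n_patches_per_slice * n_slices;  // 1280 image tokens
    if (is_glm_edge) {
        nx += 2;                              // + begin/end-of-image embeddings
    }
    printf("nx = %d\n", nx);                  // prints 1282
    return 0;
}
```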
@@ -198,11 +212,27 @@ std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
-    bool ok = clip_image_batch_encode(
-        ctx->ctx_clip,
-        ctx->n_threads,
-        &image_tokens->batch_f32,
-        ctx->image_embd_v.data());
+    bool ok = false;
+
+    if (clip_is_llava(ctx->ctx_clip)) {
+        // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
+        const auto & entries = image_tokens->batch_f32.entries;
+        for (size_t i = 0; i < entries.size(); i++) {
+            int n_tokens_per_image = clip_n_patches(ctx->ctx_clip);
+            ok = clip_image_encode(
+                ctx->ctx_clip,
+                ctx->n_threads,
+                entries[i].get(),
+                ctx->image_embd_v.data() + i*n_mmproj_embd*n_tokens_per_image);
+        }
+    } else {
+        ok = clip_image_batch_encode(
+            ctx->ctx_clip,
+            ctx->n_threads,
+            &image_tokens->batch_f32,
+            ctx->image_embd_v.data());
+    }
+
     return ok ? 0 : 1;
 }
 
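Because llava-style projectors are encoded one slice at a time here, each slice's embeddings land at a fixed offset inside the flat image_embd_v buffer. The sketch below illustrates that offset arithmetic with assumed sizes; it uses a plain std::vector instead of the real clip API and is not part of the commit.

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Assumed sizes, for illustration only.
    const int n_mmproj_embd      = 4096; // what clip_n_mmproj_embd() might return
    const int n_tokens_per_image = 576;  // what clip_n_patches() might return
    const int n_entries          = 3;    // batch_f32.entries.size()

    // Flat buffer: one row of n_mmproj_embd floats per image token, all slices back to back.
    std::vector<float> image_embd_v((size_t)n_entries * n_tokens_per_image * n_mmproj_embd);

    for (int i = 0; i < n_entries; i++) {
        // Mirrors the destination pointer in mtmd_encode():
        //     image_embd_v.data() + i * n_mmproj_embd * n_tokens_per_image
        float * dst = image_embd_v.data() + (size_t)i * n_mmproj_embd * n_tokens_per_image;
        printf("slice %d writes %d tokens starting at float offset %zu\n",
               i, n_tokens_per_image, (size_t)(dst - image_embd_v.data()));
    }
    return 0;
}
```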

@@ -268,28 +298,31 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
     int32_t ret;
     llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
+    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
 
     for (auto & chunk : chunks) {
         bool is_last = &chunk == &chunks.back();
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-            // TODO @ngxson : may need to split into smaller batches
             text_batch.n_tokens = chunk.tokens_text.size();
-            for (size_t i = 0; i < chunk.tokens_text.size(); i++) {
-                text_batch.token   [i]    = chunk.tokens_text[i];
-                text_batch.pos     [i]    = n_past++;
-                text_batch.n_seq_id[i]    = 1;
-                text_batch.seq_id  [i][0] = seq_id;
-                text_batch.logits  [i]    = false;
-            }
-            if (is_last) {
-                // always get logits for last input chunk
-                text_batch.logits[text_batch.n_tokens - 1] = true;
-            }
-            ret = llama_decode(lctx, text_batch);
-            if (ret != 0) {
-                LOG_ERR("failed to decode text\n");
-                llama_batch_free(text_batch);
-                return ret;
+            size_t i = 0;
+            while (i < chunk.tokens_text.size()) { // split into batches
+                for (; i < chunk.tokens_text.size() && text_batch.n_tokens < n_batch; i++) {
+                    text_batch.token   [i]    = chunk.tokens_text[i];
+                    text_batch.pos     [i]    = n_past++;
+                    text_batch.n_seq_id[i]    = 1;
+                    text_batch.seq_id  [i][0] = seq_id;
+                    text_batch.logits  [i]    = false;
+                }
+                if (is_last) {
+                    // always get logits for last input chunk
+                    text_batch.logits[text_batch.n_tokens - 1] = true;
+                }
+                ret = llama_decode(lctx, text_batch);
+                if (ret != 0) {
+                    LOG_ERR("failed to decode text\n");
+                    llama_batch_free(text_batch);
+                    return ret;
+                }
             }
 
         } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
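The rewritten text path above streams a long token chunk to the model in n_batch-sized pieces instead of one oversized decode call. Here is a minimal, library-free sketch of that chunking pattern; decode() is a dummy stand-in (not a llama.cpp function) and the per-token batch bookkeeping is intentionally omitted.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Dummy stand-in for llama_decode(); it only reports what it would process.
static int decode(const std::vector<int> & tokens, size_t first, size_t count, bool want_logits) {
    printf("decode tokens [%zu, %zu), logits for last token: %s\n",
           first, first + count, want_logits ? "yes" : "no");
    return 0;
}

int main() {
    const std::vector<int> tokens(1000, 1); // pretend text chunk of 1000 tokens
    const size_t n_batch       = 512;       // assumed decode batch size
    const bool   is_last_chunk = true;

    size_t i = 0;
    while (i < tokens.size()) { // split into batches
        const size_t n = std::min(n_batch, tokens.size() - i);
        // only the final batch of the last chunk needs logits
        const bool want_logits = is_last_chunk && (i + n == tokens.size());
        if (decode(tokens, i, n, want_logits) != 0) {
            return 1;
        }
        i += n;
    }
    return 0;
}
```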
@@ -310,20 +343,42 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             }
 
             int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
+            int32_t i_batch = 0;
+            int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
             float * embd = mtmd_get_output_embd(ctx);
-            decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
-            int64_t t1 = ggml_time_ms();
-            ret = llama_decode(lctx, batch_img.batch);
-            if (ret != 0) {
-                LOG_ERR("failed to decode image\n");
-                llama_batch_free(text_batch);
-                return ret;
+
+            if (mtmd_decode_use_non_causal(ctx)) {
+                llama_set_causal_attn(lctx, false);
             }
-            if (ctx->print_timings) {
-                LOG_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
+
+            while (i_batch < n_img_batches) { // split into batches
+                int32_t pos_offset = i_batch*n_batch;
+                int32_t n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+                float * embd_batch = embd + pos_offset*n_mmproj_embd;
+                decode_embd_batch batch_img(embd_batch, n_tokens_batch, n_past, 0);
+
+                printf("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+                int64_t t1 = ggml_time_ms();
+                ret = llama_decode(lctx, batch_img.batch);
+                if (ret != 0) {
+                    LOG_ERR("failed to decode image\n");
+                    llama_set_causal_attn(lctx, true); // restore causal attn
+                    llama_batch_free(text_batch);
+                    return ret;
+                }
+
+                if (ctx->print_timings) {
+                    LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
+                }
+
+                i_batch++;
+                n_past += n_tokens_batch;
             }
 
-            n_past += n_tokens;
+            if (mtmd_decode_use_non_causal(ctx)) {
+                llama_set_causal_attn(lctx, true);
+            }
 
         } else {
             GGML_ASSERT(false && "chunk type not supported");
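GGML_PAD(n_tokens, n_batch) / n_batch is effectively a ceiling division: the image token count is rounded up to a multiple of n_batch before dividing. The small sketch below uses assumed numbers (n_tokens = 1280, n_batch = 512) to show how the loop above splits the image into batches of 512, 512 and 256 tokens; pad_up() mirrors GGML_PAD for a power-of-two n_batch and is not a ggml function.

```cpp
#include <algorithm>
#include <cstdio>

// Round x up to the next multiple of n (same result as GGML_PAD when n is a power of two).
static int pad_up(int x, int n) { return (x + n - 1) / n * n; }

int main() {
    const int n_tokens = 1280; // assumed image token count
    const int n_batch  = 512;  // assumed decode batch size

    const int n_img_batches = pad_up(n_tokens, n_batch) / n_batch; // ceil(1280/512) = 3
    for (int i_batch = 0; i_batch < n_img_batches; i_batch++) {
        const int pos_offset     = i_batch * n_batch;
        const int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
        printf("batch %d/%d: offset %d, %d tokens\n",
               i_batch + 1, n_img_batches, pos_offset, n_tokens_batch);
    }
    return 0;
}
```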
