Commit 4217d42

handle llama 4 preprocessing
1 parent b74122f commit 4217d42

File tree: 1 file changed

tools/mtmd/mtmd.cpp
Lines changed: 54 additions & 12 deletions
@@ -64,15 +64,18 @@ struct mtmd_context {
     int n_threads;
     std::string image_marker;
 
-    // for minicpmv, we need special tokens in-between slices
+    // for llava-uhd style models, we need special tokens in-between slices
     mtmd_slice_tmpl slice_tmpl    = MTMD_SLICE_TMPL_NONE;
     llama_token tok_ov_img_start  = LLAMA_TOKEN_NULL; // overview image
     llama_token tok_ov_img_end    = LLAMA_TOKEN_NULL; // overview image
     llama_token tok_slices_start  = LLAMA_TOKEN_NULL; // start of all slices
     llama_token tok_slices_end    = LLAMA_TOKEN_NULL; // end of all slices
-    llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice
-    llama_token tok_sli_img_end   = LLAMA_TOKEN_NULL; // single slice
+    llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start
+    llama_token tok_sli_img_end   = LLAMA_TOKEN_NULL; // single slice end
+    llama_token tok_sli_img_mid   = LLAMA_TOKEN_NULL; // between 2 slices
     llama_token tok_row_end       = LLAMA_TOKEN_NULL; // end of row
+    bool        tok_row_end_trail = false;
+    bool        ov_img_first      = false;
 
     bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
 
@@ -96,6 +99,7 @@ struct mtmd_context {
 
         use_mrope = clip_is_qwen2vl(ctx_clip);
 
+        projector_type proj = clip_get_projector_type(ctx_clip);
         int minicpmv_version = clip_is_minicpmv(ctx_clip);
         if (minicpmv_version == 2) {
             // minicpmv 2.5 format:
@@ -108,6 +112,8 @@
             tok_sli_img_start = tok_ov_img_start;
             tok_sli_img_end   = tok_ov_img_end;
             tok_row_end       = lookup_token("\n");
+            tok_row_end_trail = false; // no trailing end-of-row token
+            ov_img_first      = true;
 
         } else if (minicpmv_version == 3 || minicpmv_version == 4) {
             // minicpmv 2.6 format:
@@ -118,9 +124,24 @@
             tok_sli_img_start = lookup_token("<slice>");
             tok_sli_img_end   = lookup_token("</slice>");
             tok_row_end       = lookup_token("\n");
+            tok_row_end_trail = false; // no trailing end-of-row token
+            ov_img_first      = true;
 
         } else if (minicpmv_version != 0) {
             GGML_ASSERT(false && "unsupported minicpmv version");
+        } else if (proj == PROJECTOR_TYPE_LLAMA4) {
+            // llama 4 format:
+            // <|image_start|>
+            //     (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
+            //     (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
+            //     ... <|tile_y_separator|>   <-- trailing end-of-row token
+            //     <|image|> (overview)       <-- overview image is last
+            // <|image_end|>
+            tok_ov_img_start  = lookup_token("<|image|>");
+            tok_sli_img_mid   = lookup_token("<|tile_x_separator|>");
+            tok_row_end       = lookup_token("<|tile_y_separator|>");
+            tok_row_end_trail = true;  // add trailing end-of-row token
+            ov_img_first      = false; // overview image is last
         }
     }
 
@@ -250,13 +271,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
 
     } else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
-        // <|image_start|> ... (image embeddings) ... <|image_end|>
+        // (more details in mtmd_context constructor)
        marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
 
-    }
-
-    else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
+    } else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
         // <img> ... (image embeddings) ... </img>
         marker_modified = "<img>" + ctx->image_marker + "</img>";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
@@ -347,11 +366,19 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id);
             GGML_ASSERT(chunks.size() > 0);
 
-            // add overview image
-            add_text_chunk({ctx->tok_ov_img_start});
-            output->entries.emplace_back(std::move(chunks.front()));
+            auto ov_chunk = std::move(chunks.front());
             chunks.erase(chunks.begin());
-            add_text_chunk({ctx->tok_ov_img_end});
+
+            // add overview image (first)
+            if (ctx->ov_img_first) {
+                if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
+                    add_text_chunk({ctx->tok_ov_img_start});
+                }
+                output->entries.emplace_back(std::move(ov_chunk));
+                if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
+                    add_text_chunk({ctx->tok_ov_img_end});
+                }
+            }
 
             // add slices
             if (!chunks.empty()) {
@@ -364,15 +391,19 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 }
                 for (int y = 0; y < n_row; y++) {
                     for (int x = 0; x < n_col; x++) {
+                        const bool is_last_in_row = (x == n_col - 1);
                         if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) {
                             add_text_chunk({ctx->tok_sli_img_start});
                         }
                         output->entries.emplace_back(std::move(chunks[y * n_col + x]));
                         if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
                             add_text_chunk({ctx->tok_sli_img_end});
                         }
+                        if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) {
+                            add_text_chunk({ctx->tok_sli_img_mid});
+                        }
                     }
-                    if (ctx->tok_row_end != LLAMA_TOKEN_NULL && y != n_row - 1) {
+                    if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) {
                         add_text_chunk({ctx->tok_row_end});
                     }
                 }
@@ -381,6 +412,17 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 }
            }
 
+            // add overview image (last)
+            if (!ctx->ov_img_first) {
+                if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
+                    add_text_chunk({ctx->tok_ov_img_start});
+                }
+                output->entries.emplace_back(std::move(ov_chunk));
+                if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
+                    add_text_chunk({ctx->tok_ov_img_end});
+                }
+            }
+
         } else {
             size_t n_tokens = 0;
             for (const auto & entry : batch_f32.entries) {

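To make the new interleaving rules concrete, here is a small standalone sketch (not part of the commit) that prints the marker layout described in the llama 4 branch above for an n_row x n_col grid of slices. The emit() helper and the (slice)/(overview) placeholders are illustrative assumptions; in mtmd.cpp the slices are image-embedding chunks and the separators are looked-up tokens (tok_sli_img_mid between slices within a row, tok_row_end after every row because tok_row_end_trail is true, and tok_ov_img_start before the overview image, which comes last since ov_img_first is false).

// Sketch only: mirrors the llama 4 layout from the diff above; emit() is a
// stand-in for appending a text chunk or an image-embedding chunk.
#include <cstdio>

static void emit(const char * marker) {
    std::printf("%s ", marker);
}

static void emit_llama4_layout(int n_row, int n_col) {
    emit("<|image_start|>");                    // added around the image marker in mtmd_tokenize()
    for (int y = 0; y < n_row; y++) {
        for (int x = 0; x < n_col; x++) {
            emit("(slice)");                    // one tile's image embeddings
            if (x != n_col - 1) {
                emit("<|tile_x_separator|>");   // tok_sli_img_mid: only between slices in a row
            }
        }
        emit("<|tile_y_separator|>");           // tok_row_end: kept after the last row too,
    }                                           // because tok_row_end_trail == true
    emit("<|image|>");                          // tok_ov_img_start
    emit("(overview)");                         // overview image is last (ov_img_first == false)
    emit("<|image_end|>");
    std::printf("\n");
}

int main() {
    // For a 2x2 grid this prints (single line, wrapped here for readability):
    // <|image_start|> (slice) <|tile_x_separator|> (slice) <|tile_y_separator|>
    //   (slice) <|tile_x_separator|> (slice) <|tile_y_separator|> <|image|> (overview) <|image_end|>
    emit_llama4_layout(2, 2);
    return 0;
}

The minicpmv paths are unchanged by this layout: they set ov_img_first = true, leave tok_row_end_trail = false and tok_sli_img_mid unset, so the overview image still comes first, rows end with "\n" except the last one, and nothing is inserted between slices.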