Skip to content

Commit e172313

Browse files
committed
feat: Partial support for full templating for idefics3 in mtmd
There are still errors encoding some of the image chunks, but the token sequence now matches the `transformers` implementation _almost_ perfectly, except for the double newline before the global image, which shows up as two consecutive newline tokens instead of a single double-newline token. I think this is happening because the blocks are tokenized separately and then concatenated. Branch: gabe-l-hart/GraniteDocling Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 8819c96 commit e172313

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

tools/mtmd/mtmd.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ enum mtmd_slice_tmpl {
7676
MTMD_SLICE_TMPL_MINICPMV_2_5,
7777
MTMD_SLICE_TMPL_MINICPMV_2_6,
7878
MTMD_SLICE_TMPL_LLAMA4,
79-
// TODO @ngxson : add support for idefics (SmolVLM)
79+
MTMD_SLICE_TMPL_IDEFICS3,
8080
};
8181

8282
const char * mtmd_default_marker() {
@@ -127,6 +127,9 @@ struct mtmd_context {
127127

128128
bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
129129

130+
// string template for slice image delimiters with row/col (idefics3)
131+
std::string sli_img_start_tmpl;
132+
130133
// for whisper, we pre-calculate the mel filter bank
131134
whisper_preprocessor::whisper_filters w_filters;
132135

@@ -245,8 +248,12 @@ struct mtmd_context {
245248

246249
} else if (proj == PROJECTOR_TYPE_IDEFICS3) {
247250
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
248-
img_beg = "<fake_token_around_image><global-img>";
249-
img_end = "<fake_token_around_image>";
251+
slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3;
252+
tok_ov_img_start = {lookup_token("\n"), lookup_token("<fake_token_around_image>"), lookup_token("<global-img>")};
253+
tok_ov_img_end = {lookup_token("<fake_token_around_image>")};
254+
tok_row_end = {lookup_token("\n")};
255+
img_beg = "<fake_token_around_image>";
256+
sli_img_start_tmpl = "<fake_token_around_image><row_%d_col_%d>";
250257

251258
} else if (proj == PROJECTOR_TYPE_PIXTRAL) {
252259
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
@@ -504,6 +511,7 @@ struct mtmd_tokenizer {
504511
ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5
505512
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6
506513
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
514+
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3
507515
) {
508516
const int n_col = batch_f32.grid_x;
509517
const int n_row = batch_f32.grid_y;
@@ -537,6 +545,12 @@ struct mtmd_tokenizer {
537545
const bool is_last_in_row = (x == n_col - 1);
538546
if (!ctx->tok_sli_img_start.empty()) {
539547
add_text(ctx->tok_sli_img_start);
548+
} else if (!ctx->sli_img_start_tmpl.empty()) {
549+
// If using a template to precede a slice image
550+
const size_t sz = std::snprintf(nullptr, 0, ctx->sli_img_start_tmpl.c_str(), y+1, x+1) + 1;
551+
std::unique_ptr<char[]> buf(new char[sz]);
552+
std::snprintf(buf.get(), sz, ctx->sli_img_start_tmpl.c_str(), y+1, x+1);
553+
add_text(std::string(buf.get(), buf.get() + sz - 1), true);
540554
}
541555
cur.entries.emplace_back(std::move(chunks[y * n_col + x]));
542556
if (!ctx->tok_sli_img_end.empty()) {
@@ -780,7 +794,10 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
780794
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
781795
bool ok = false;
782796

783-
if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) {
797+
if (clip_is_llava(ctx_clip)
798+
|| clip_is_minicpmv(ctx_clip)
799+
|| clip_is_glm(ctx_clip)
800+
|| clip_is_idefics3(ctx_clip)) {
784801
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
785802
const auto & entries = image_tokens->batch_f32.entries;
786803
for (size_t i = 0; i < entries.size(); i++) {

0 commit comments

Comments
 (0)