Commit 1268dc3

Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into sf/deepseek-ocr
2 parents: a65ddf5 + 88032f4

File tree

4 files changed: +86, -39 lines


tools/mtmd/clip-impl.h

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@
 #define TN_MM_INP_NORM_B   "mm.input_norm.bias"
 #define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
-#define TN_MM_PROJECTOR    "mm.model.fc.weight"         // idefics3
+#define TN_MM_PROJECTOR    "mm.model.fc.%s"             // idefics3, deepseekocr
 #define TN_MM_PATCH_MERGER "mm.patch_merger.weight"     // mistral small 3.1
 #define TN_TOK_IMG_BREAK   "v.token_embd.img_break"     // pixtral
 #define TN_TOK_GLM_BOI     "adapter.boi"                // glm-edge (these embeddings are not in text model)
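
Note: the projector tensor name is now a printf-style template, so one macro covers both the weight and the bias. A minimal sketch of the expansion as the loader changes in clip.cpp below use it (string_format is the existing helper in clip-impl.h):

    // "mm.model.fc.%s" expands to "mm.model.fc.weight" / "mm.model.fc.bias"
    model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
    model.fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));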

tools/mtmd/clip.cpp

Lines changed: 81 additions & 37 deletions
@@ -316,7 +316,8 @@ struct clip_model {
     ggml_tensor * post_ln_w;
     ggml_tensor * post_ln_b;

-    ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
+    ggml_tensor * fc_w;
+    ggml_tensor * fc_b;
     ggml_tensor * mm_fc_w;
     ggml_tensor * mm_fc_b;

@@ -623,7 +624,7 @@ struct clip_graph {
             // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
             const int scale_factor = model.hparams.n_merge;
             cur = build_patch_merge_permute(cur, scale_factor);
-            cur = ggml_mul_mat(ctx0, model.projection, cur);
+            cur = ggml_mul_mat(ctx0, model.fc_w, cur);

         } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
             // pixel unshuffle block
@@ -689,7 +690,8 @@
             if (hparams.is_global_attn(il) == false) {
                 // local attention layer - apply window partition
                 // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
-                cur = ggml_win_part(ctx0, cur, 14);
+                //cur = ggml_win_part(ctx0, cur, 14);
+                cur = window_partition(ctx0, cur, 14);
             }

             const int64_t W = cur->ne[1];
@@ -761,7 +763,7 @@

             if (hparams.is_global_attn(il) == false) {
                 // local attention layer - reverse window partition
-                cur = ggml_win_unpart(ctx0, cur, w0, h0, 14);
+                cur = window_unpartition(ctx0, cur, w0, h0, 14);
             }

             // re-add the layer input, e.g., residual
@@ -844,15 +846,12 @@ struct clip_graph {
                 ggml_row_size(global_features_2->type, n_embd), 0);

         ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1);
-        global_features = build_global_local_features(
-            ctx0,
-            global_features,
-            n_patches_y,
-            n_patches_x,
-            n_embd
-        );
+        global_features = ggml_reshape_2d(ctx0, global_features, 2 * n_embd, n_patches);
+        global_features = ggml_cont(ctx0, global_features);
+        global_features = ggml_mul_mat(ctx0, model.fc_w, global_features);
+        global_features = ggml_add(ctx0, global_features, model.fc_b);
+        global_features = build_global_local_features(ctx0, global_features);
         ggml_build_forward_expand(gf, global_features);
-
         return gf;
     }
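
Note: the fc projection now happens inline, before the newline/separator decoration. A rough shape trace, assuming n_embd = 1280 and the 64x64 patch grid hard-coded in build_global_local_features below:

    // after concat + reshape: ne = (2*n_embd, n_patches) = (2560, 4096)
    // ggml_mul_mat(fc_w, x)  -> ne = (fc_w->ne[1], n_patches) = (1280, 4096)
    // ggml_add broadcasts fc_b (1280) across all 4096 columns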

@@ -861,41 +860,32 @@ struct clip_graph {
     // view_separator: [n_dim]

     ggml_tensor * build_global_local_features(ggml_context * ctx0,
-                                              ggml_tensor * global_features,
-                                              int h,
-                                              int w,
-                                              int n_dim) {
+                                              ggml_tensor * global_features) {
         GGML_ASSERT(model.image_newline != nullptr);
         GGML_ASSERT(model.view_seperator != nullptr);
-        GGML_ASSERT(global_features->ne[0] == static_cast<int64_t>(n_dim));
-        GGML_ASSERT(global_features->ne[1] == static_cast<int64_t>(2 * (h * w)));

         // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
-        ggml_tensor * t = ggml_reshape_3d(ctx0, global_features, n_dim, w, h); // (n_dim, w, h)
-        t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim)
-
-        // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
-        ggml_tensor * nl = ggml_reshape_3d(ctx0, model.image_newline, 1, 1, n_dim); // (1, 1, n_dim)
+        ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h)
+        t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
+        ggml_tensor * nl = ggml_cont(ctx0, ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
+        nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows
+        nl = ggml_cont(ctx0, nl);

-        ggml_tensor * nl_target_shape = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, h, n_dim); // (1, h, n_dim)
-        nl = ggml_repeat(ctx0, nl, nl_target_shape); // (1, h, n_dim)
-        nl = ggml_permute(ctx0, nl, 1, 0, 2, 3); // (h, 1, n_dim)

-        // 3) concat along width dimension (dim=1): (h, w, n_dim) + (h, 1, n_dim) -> (h, w+1, n_dim)
+        // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
         t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim)

-        // 4) flatten back to token axis: (h, w+1, n_dim) -> (n_dim, h*(w+1))
-        t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (n_dim, w+1, h)
-        t = ggml_cont_2d(ctx0, t, n_dim, (w + 1) * h); // (n_dim, h*(w+1))
+        t = ggml_reshape_2d(ctx0, t, 1280, 64 * (64 + 1)); // (n_dim, h*(w+1))
+

         // 5) append view_separator as an extra "token":
         //    view_separator: [n_dim] -> [n_dim, 1]
-        ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
+        ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, 1280, 1); // (n_dim, 1)

         // concat along token dimension (dim=1):
-        ggml_tensor * global_local_features = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)
+        t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)

-        return global_local_features;
+        return t;
     }
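
Note: the h/w/n_dim parameters are gone, so this version only handles the fixed 1280-dim, 64x64 grid. The token count works out to one image_newline per row plus a single trailing view_seperator: 64 * (64 + 1) + 1 = 4161 embeddings of width 1280.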
@@ -2476,6 +2466,46 @@ struct clip_graph {
         return inpL;
     }

+    static ggml_tensor * window_partition(ggml_context * ctx, ggml_tensor * x, int window) {
+        auto [c, w, h, b] = x->ne;
+        // same as
+        // x = ggml_win_part(m, x, window);
+        // x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]);
+
+        int64_t px  = (window - w % window) % window;
+        int64_t py  = (window - h % window) % window;
+        int64_t npw = (w + px) / window;
+        int64_t nph = (h + py) / window;
+
+        if (px > 0 || py > 0) {
+            x = ggml_pad(ctx, x, 0, int(px), int(py), 0);
+        }
+        x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b);
+        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
+        x = ggml_reshape_4d(ctx, x, c, window, window, npw * nph * b);
+        return x;
+    }
+
+    static ggml_tensor * window_unpartition(ggml_context * m, ggml_tensor * x, int w, int h, int window) {
+        int64_t c = x->ne[0];
+        // same as
+        // x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]);
+        // x = ggml_win_unpart(m, x, w, h, window);
+
+        int64_t px  = (window - w % window) % window;
+        int64_t py  = (window - h % window) % window;
+        int64_t npw = (w + px) / window;
+        int64_t nph = (h + py) / window;
+
+        int64_t b = x->ne[3] / (npw * nph);
+        x = ggml_reshape_4d(m, x, c * window, window, npw, nph * b);
+        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
+        x = ggml_reshape_4d(m, x, c, w + px, h + py, b);
+        x = ggml_view_4d(m, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
+        x = ggml_cont(m, x);
+        return x;
+    }
+
     // build the input after conv2d (inp_raw --> patches)
     // returns tensor with shape [n_embd, n_patches]
     ggml_tensor * build_enc_inp(ggml_tensor * inp_raw,
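
Note: the "same as" comments assert equivalence with ggml_win_part / ggml_win_unpart plus the reshape to one window per batch row. A worked example of the padding arithmetic for the 14x14 windows used in the attention loop above, assuming a hypothetical 65x65 feature map:

    // w = h = 65, window = 14
    // px = py = (14 - 65 % 14) % 14 = 5    -> padded to 70x70
    // npw = nph = (65 + 5) / 14 = 5        -> 5*5 = 25 windows
    // window_partition:   (C, 65, 65, B)    -> (C, 14, 14, 25*B)
    // window_unpartition: (C, 14, 14, 25*B) -> (C, 65, 65, B), padding cropped via ggml_view_4d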
@@ -3488,7 +3518,7 @@ struct clip_model_loader {
                 } break;
             case PROJECTOR_TYPE_IDEFICS3:
                 {
-                    model.projection = get_tensor(TN_MM_PROJECTOR);
+                    model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
                 } break;
             case PROJECTOR_TYPE_LFM2:
             case PROJECTOR_TYPE_KIMIVL:
@@ -3561,13 +3591,13 @@ struct clip_model_loader {
                 } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
-                    model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
                     model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
                     model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
                 } break;
             case PROJECTOR_TYPE_COGVLM:
                 {
-                    model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
                     model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight"));
                     model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias"));
                     model.mm_h_to_4h_w = get_tensor(string_format(TN_MM_H_TO_4H, "weight"));
@@ -3617,6 +3647,9 @@ struct clip_model_loader {
                 }
                 model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
                 model.view_seperator = get_tensor(TN_IMAGE_SEPERATOR, false);
+                model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
+                model.fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
+

                 break;
             default:
@@ -5086,6 +5119,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             {
                 n_patches += 2; // for BOI and EOI token embeddings
             } break;
+        case PROJECTOR_TYPE_DEEPSEEKOCR:
+            {
+                n_patches += 2;
+            } break;
         default:
             GGML_ABORT("unsupported projector type");
     }
@@ -5417,6 +5454,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_JANUS_PRO:
         case PROJECTOR_TYPE_COGVLM:
+        case PROJECTOR_TYPE_DEEPSEEKOCR:
             {
                 // do nothing
             } break;
@@ -5512,7 +5550,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_GEMMA3:
            return ctx->model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
-            return ctx->model.projection->ne[1];
+            return ctx->model.fc_w->ne[1];
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_VOXTRAL:
             return ctx->model.mm_2_w->ne[1];
@@ -5527,6 +5565,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_COGVLM:
             return ctx->model.mm_4h_to_h_w->ne[1];
+        case PROJECTOR_TYPE_DEEPSEEKOCR:
+            return ctx->model.fc_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
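
Note: in ggml, ggml_mul_mat(fc_w, cur) produces vectors of length fc_w->ne[1], so this returns the fc projector's output width — presumably 1280, if the constants hard-coded in build_global_local_features reflect the converted DeepSeek-OCR weights.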
@@ -5557,6 +5597,10 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
 }

+bool clip_is_deepseekocr(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_DEEPSEEKOCR;
+}
+
 bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
     return ctx->model.modality == CLIP_MODALITY_VISION;
 }

tools/mtmd/clip.h

Lines changed: 2 additions & 0 deletions
@@ -105,6 +105,8 @@ bool clip_is_glm(const struct clip_ctx * ctx);
 bool clip_is_qwen2vl(const struct clip_ctx * ctx);
 bool clip_is_llava(const struct clip_ctx * ctx);
 bool clip_is_gemma3(const struct clip_ctx * ctx);
+bool clip_is_deepseekocr(const struct clip_ctx * ctx);
+

 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);

tools/mtmd/mtmd.cpp

Lines changed: 2 additions & 1 deletion
@@ -810,7 +810,8 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)

     if (clip_is_llava(ctx_clip)
             || clip_is_minicpmv(ctx_clip)
-            || clip_is_glm(ctx_clip)) {
+            || clip_is_glm(ctx_clip)
+            || clip_is_deepseekocr(ctx_clip)) {
         // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
         const auto & entries = image_tokens->batch_f32.entries;
         for (size_t i = 0; i < entries.size(); i++) {
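
Note: with the new predicate in this chain, DeepSeek-OCR takes the same non-batched path as LLaVA, MiniCPM-V and GLM-Edge — each clip_image_f32 entry in batch_f32 is encoded by its own call inside the loop shown above, per the batching TODO.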
