Skip to content

Commit 790bbb9

Browse files
committed
sam warmup working
1 parent b32bb5e commit 790bbb9

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

tools/mtmd/clip-impl.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,7 @@
133133
#define TN_SAM_POS_EMBD "v.sam.pos_embd"
134134
#define TN_SAM_PATCH_EMBD "v.sam.patch_embd.%s"
135135
#define TN_SAM_PRE_NORM "v.sam.blk.%d.pre_ln.%s"
136-
#define TN_SAM_POST_NORM "v.sam.blk.%d.post_ln"
136+
#define TN_SAM_POST_NORM "v.sam.blk.%d.post_ln.%s"
137137
#define TN_SAM_ATTN_POS_H "v.sam.blk.%d.attn.pos_h"
138138
#define TN_SAM_ATTN_POS_W "v.sam.blk.%d.attn.pos_w"
139139
#define TN_SAM_ATTN_QKV "v.sam.blk.%d.attn.qkv.%s"

tools/mtmd/clip.cpp

Lines changed: 12 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -225,17 +225,7 @@ struct clip_hparams {
225225

226226
// sam vit deepseek-ocr
227227
std::vector<int32_t> global_attn_indices() const {
228-
switch (n_embd) {
229-
case 768: return { 2, 5, 8, 11 };
230-
case 1024: return { 5, 11, 17, 23 };
231-
case 1280: return { 7, 15, 23, 31 };
232-
default:
233-
{
234-
fprintf(stderr, "%s: unsupported n_enc_state = %d\n", __func__, n_embd);
235-
} break;
236-
};
237-
238-
return {};
228+
return { 2, 5, 8, 11 };
239229
}
240230

241231
bool is_global_attn(int32_t layer) const {
@@ -455,7 +445,7 @@ struct clip_model {
455445
ggml_tensor * net_2;
456446
ggml_tensor * net_3;
457447

458-
int32_t n_sam_layers = 0; // used by deepseek-ocr sam encoder
448+
int32_t n_sam_layers = 12; // used by deepseek-ocr sam encoder
459449

460450
std::vector<clip_layer> sam_layers;
461451

@@ -721,7 +711,7 @@ struct clip_graph {
721711
Qcur = ggml_reshape_3d(ctx0, Qcur, enc_d_heads, W * H, B * enc_n_heads);
722712

723713
ggml_tensor * Kcur =
724-
ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 1 * cur->nb[3]);
714+
ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], cur->nb[3]);
725715
Kcur = ggml_reshape_4d(ctx0, Kcur, enc_d_heads, enc_n_heads, W * H, B);
726716
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
727717
Kcur = ggml_reshape_3d(ctx0, Kcur, enc_d_heads, W * H, B * enc_n_heads);
@@ -740,12 +730,12 @@ struct clip_graph {
740730

741731
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcur, Qcur);
742732

743-
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_n_heads));
733+
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads));
744734

745735
struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W);
746736
struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H);
747737

748-
struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_n_heads, W, H, B * enc_n_embd);
738+
struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
749739

750740
struct ggml_tensor * rel_w = ggml_cont(
751741
ctx0,
@@ -763,7 +753,7 @@ struct clip_graph {
763753
ctx0,
764754
ggml_cont(ctx0, ggml_permute(ctx0, ggml_reshape_4d(ctx0, KQV, enc_d_heads, W * H, enc_n_heads, B),
765755
0, 2, 1, 3)),
766-
n_embd, W, H, B);
756+
enc_n_embd, W, H, B);
767757

768758
cur = ggml_mul_mat(ctx0, layer.o_w, cur);
769759
cur = ggml_add_inplace(ctx0, cur, layer.o_b);
@@ -2492,9 +2482,11 @@ struct clip_graph {
24922482
// patch_embed_proj_w shape = [768, 3, 16, 16]
24932483
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embed_proj_w, inp_raw, enc_patch_size, enc_patch_size, 0, 0,
24942484
1, 1); // [64, 64, 768]
2495-
inp = ggml_reshape_2d(ctx0, inp, enc_n_patches, enc_n_embd); // [4096, 768]
2485+
inp = ggml_reshape_2d(ctx0, inp, enc_n_patches * enc_n_patches, enc_n_embd); // [4096, 768]
24962486
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [768, 4096]
24972487
inp = ggml_add(ctx0, inp, model.patch_embed_proj_b);
2488+
inp = ggml_cont(ctx0, inp);
2489+
inp = ggml_reshape_4d(ctx0, inp, enc_n_embd, enc_n_patches, enc_n_patches, 1);
24982490
cb(inp, "enc_patch_bias", -1);
24992491
return inp;
25002492
}
@@ -3193,8 +3185,9 @@ struct clip_model_loader {
31933185
} break;
31943186
case PROJECTOR_TYPE_DEEPSEEKOCR:
31953187
{
3196-
hparams.set_limit_image_tokens(8, 1024);
3197-
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
3188+
hparams.patch_size = 16;
3189+
hparams.image_size = 1024;
3190+
hparams.warmup_image_size = 1024;
31983191
} break;
31993192
default:
32003193
break;

0 commit comments

Comments
 (0)