
Commit b6b9f02

loading sam tensors
1 parent 43a130b · commit b6b9f02

File tree

2 files changed: +63, −35 lines


tools/mtmd/clip-impl.h

Lines changed: 8 additions & 9 deletions
@@ -131,18 +131,17 @@
 
 // deepseek-ocr
 #define TN_SAM_POS_EMBD "sam.pos_embd"
-#define TN_SAM_PATCH_EMBD "sam.patch_embd"
-#define TN_SAM_PRE_NORM "sam.blk.%d.pre_ln"
+#define TN_SAM_PATCH_EMBD "sam.patch_embd.%s"
+#define TN_SAM_PRE_NORM "sam.blk.%d.pre_ln.%s"
 #define TN_SAM_POST_NORM "sam.blk.%d.post_ln"
 #define TN_SAM_ATTN_POS_H "sam.blk.%d.attn.pos_h"
 #define TN_SAM_ATTN_POS_W "sam.blk.%d.attn.pos_w"
-#define TN_SAM_ATTN_QKV "sam.blk.%d.attn.qkv"
-#define TN_SAM_ATTN_OUT "sam.blk.%d.attn.out"
-#define TN_SAM_MLP_LIN_1 "sam.blk.%d.mlp.lin1"
-#define TN_SAM_MLP_LIN_2 "sam.blk.%d.mlp.lin2"
-#define TN_SAM_NECK "sam.neck.%d"
-#define TN_SAM_NET_2 "sam.net_2"
-#define TN_SAM_NET_3 "sam.net_3"
+#define TN_SAM_ATTN_QKV "sam.blk.%d.attn.qkv.%s"
+#define TN_SAM_ATTN_OUT "sam.blk.%d.attn.out.%s"
+#define TN_SAM_FFN_UP "sam.blk.%d.mlp.lin1.%s"
+#define TN_SAM_FFN_DOWN "sam.blk.%d.mlp.lin2.%s"
+#define TN_SAM_NECK "sam.neck.%d.%s"
+#define TN_SAM_NET "sam.net_%d.%s"
 
 
 #define TN_SAM_ATTN_OUT "sam.blk.%d.attn_out"
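With the ".%s" suffix baked into each pattern, a single template now covers both the weight and the bias of a tensor; the loader expands the patterns with the printf-style string_format helper used further down in clip.cpp. A rough illustration of the expansion (block index 5 chosen arbitrarily):

    std::string qkv_w = string_format(TN_SAM_ATTN_QKV, 5, "weight"); // "sam.blk.5.attn.qkv.weight"
    std::string qkv_b = string_format(TN_SAM_ATTN_QKV, 5, "bias");   // "sam.blk.5.attn.qkv.bias"
    std::string neck0 = string_format(TN_SAM_NECK, 0, "weight");     // "sam.neck.0.weight"
    std::string net3  = string_format(TN_SAM_NET, 3, "weight");      // "sam.net_3.weight"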

tools/mtmd/clip.cpp

Lines changed: 55 additions & 26 deletions
@@ -446,14 +446,18 @@ struct clip_model {
         return proj_type == PROJECTOR_TYPE_ULTRAVOX
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
     }
-    ggml_tensor * neck_conv_0;
-    ggml_tensor * neck_norm_0_w;
-    ggml_tensor * neck_norm_0_b;
-    ggml_tensor * neck_conv_1;
-    ggml_tensor * neck_norm_1_w;
-    ggml_tensor * neck_norm_1_b;
+    ggml_tensor * neck_0_w;
+    ggml_tensor * neck_1_w;
+    ggml_tensor * neck_1_b;
+    ggml_tensor * neck_2_w;
+    ggml_tensor * neck_3_w;
+    ggml_tensor * neck_3_b;
+    ggml_tensor * net_2;
+    ggml_tensor * net_3;
 
-    std::vector<clip_layer> enc_layers;
+    int32_t n_sam_layers = 0; // used by deepseek-ocr sam encoder
+
+    std::vector<clip_layer> sam_layers;
 
 };
 
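The neck tensors are now indexed by their position in the neck, mirroring the GGUF tensor names. A plausible reading of the correspondence, inferred from the loader and graph changes below:

    neck_0_w           <- sam.neck.0.weight          (conv, weight only)
    neck_1_w, neck_1_b <- sam.neck.1.{weight,bias}   (2d layer norm)
    neck_2_w           <- sam.neck.2.weight          (conv, weight only)
    neck_3_w, neck_3_b <- sam.neck.3.{weight,bias}   (2d layer norm)
    net_2, net_3       <- sam.net_2.weight, sam.net_3.weight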

@@ -683,7 +687,7 @@ struct clip_graph {
 
         // loop over layers
         for (int il = 0; il < _depth; il++) {
-            auto & layer = model.enc_layers[il];
+            auto & layer = model.sam_layers[il];
 
             // layernorm1
             cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
@@ -770,49 +774,45 @@
             cur = ggml_win_unpart(ctx0, cur, w0, h0, 14);
         }
 
-        if (layer.ls_1_w) {
-            cur = ggml_mul(ctx0, cur, layer.ls_1_w);
-            cb(cur, "attn_out_scaled", il);
-        }
-
         // re-add the layer input, e.g., residual
         cur = ggml_add(ctx0, cur, inpL);
 
-        cb(cur, "ffn_inp", il);
+        ggml_tensor * inpFF = cur;
+
+
+        cb(inpFF, "ffn_inp", il);
 
         // layernorm2
-        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+        cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
         cb(cur, "ffn_inp_normed", il);
 
         // ffn
-        cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, layer.ff_gate_w, layer.ff_gate_b, layer.ff_down_w,
+        cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w,
                         layer.ff_down_b, hparams.ffn_op, il);
 
         cb(cur, "ffn_out", il);
 
-        if (layer.ls_2_w) {
-            cur = ggml_mul(ctx0, cur, layer.ls_2_w);
-            cb(cur, "ffn_out_scaled", il);
-        }
 
         // residual 2
-        cur = ggml_add(ctx0, inpL, cur);
+        cur = ggml_add(ctx0, cur, inpFF);
         cb(cur, "layer_out", il);
 
         return cur; // B, 1024, 16, 16
     }
 
     cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));
 
-    cur = ggml_conv_2d_sk_p0(ctx0, model.neck_conv_0, cur);
+    cur = ggml_conv_2d_sk_p0(ctx0, model.neck_0_w, cur);
 
-    cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_0_w, model.neck_norm_0_b, hparams.eps);
+    cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_1_w, model.neck_1_b, hparams.eps);
 
-    cur = ggml_conv_2d_s1_ph(ctx0, model.neck_conv_1, cur);
+    cur = ggml_conv_2d_s1_ph(ctx0, model.neck_2_w, cur);
 
-    cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_1_w, model.neck_norm_1_b, hparams.eps);
+    cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_3_w, model.neck_3_b, hparams.eps);
 
-    //cur = ggml_cpy(ctx0, cur, state.embd_img);
+    // TODO: check conv padding
+    cur = ggml_conv_2d_s1_ph(ctx0, model.net_2, cur);
+    cur = ggml_conv_2d_s1_ph(ctx0, model.net_3, cur);
 
     ggml_build_forward_expand(gf, cur);
     return cur;
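The rewiring fixes the second residual: it now closes over the FFN input (inpFF) rather than the block input (inpL), and the layer-scale branches (ls_1_w, ls_2_w), which the SAM encoder does not appear to use, are dropped. In equation form (a reading of the diff, with Attn standing for the windowed attention path above):

    h   = inpL + Attn(LN1(inpL))    // residual 1
    out = h + FFN(LN2(h))           // residual 2; previously out = inpL + FFN(LN2(h))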
@@ -3604,6 +3604,35 @@ struct clip_model_loader {
                 } break;
             case PROJECTOR_TYPE_DEEPSEEK_OCR:
                 {
+                    model.pos_embed = get_tensor(TN_SAM_POS_EMBD);
+                    model.patch_embed_proj_w = get_tensor(string_format(TN_SAM_PATCH_EMBD, "weight"));
+                    model.patch_embed_proj_b = get_tensor(string_format(TN_SAM_PATCH_EMBD, "bias"));
+                    model.sam_layers.resize(model.n_sam_layers);
+                    for (int il = 0; il < model.n_sam_layers; ++il) {
+                        auto & layer = model.sam_layers[il];
+                        layer.qkv_w     = get_tensor(string_format(TN_SAM_ATTN_QKV, il, "weight"));
+                        layer.qkv_b     = get_tensor(string_format(TN_SAM_ATTN_QKV, il, "bias"));
+                        layer.o_w       = get_tensor(string_format(TN_SAM_ATTN_OUT, il, "weight"));
+                        layer.o_b       = get_tensor(string_format(TN_SAM_ATTN_OUT, il, "bias"));
+                        layer.ln_1_w    = get_tensor(string_format(TN_SAM_PRE_NORM, il, "weight"));
+                        layer.ln_1_b    = get_tensor(string_format(TN_SAM_PRE_NORM, il, "bias"));
+                        layer.ln_2_w    = get_tensor(string_format(TN_SAM_POST_NORM, il, "weight"));
+                        layer.ln_2_b    = get_tensor(string_format(TN_SAM_POST_NORM, il, "bias"));
+                        layer.rel_pos_h = get_tensor(string_format(TN_SAM_ATTN_POS_H, il));
+                        layer.rel_pos_w = get_tensor(string_format(TN_SAM_ATTN_POS_W, il));
+                        layer.ff_up_w   = get_tensor(string_format(TN_SAM_FFN_UP, il, "weight"));
+                        layer.ff_up_b   = get_tensor(string_format(TN_SAM_FFN_UP, il, "bias"));
+                        layer.ff_down_w = get_tensor(string_format(TN_SAM_FFN_DOWN, il, "weight"));
+                        layer.ff_down_b = get_tensor(string_format(TN_SAM_FFN_DOWN, il, "bias"));
+                    }
+                    model.neck_0_w = get_tensor(string_format(TN_SAM_NECK, 0, "weight"));
+                    model.neck_1_b = get_tensor(string_format(TN_SAM_NECK, 1, "bias"));
+                    model.neck_1_w = get_tensor(string_format(TN_SAM_NECK, 1, "weight"));
+                    model.neck_2_w = get_tensor(string_format(TN_SAM_NECK, 2, "weight"));
+                    model.neck_3_b = get_tensor(string_format(TN_SAM_NECK, 3, "bias"));
+                    model.neck_3_w = get_tensor(string_format(TN_SAM_NECK, 3, "weight"));
+                    model.net_2 = get_tensor(string_format(TN_SAM_NET, 2, "weight"));
+                    model.net_3 = get_tensor(string_format(TN_SAM_NET, 3, "weight"));
                 }
                 break;
             default:
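The loader leans on the printf-style string_format helper throughout. For readers following the name expansion, here is a minimal self-contained sketch of such a helper (the real one lives in clip-impl.h and may differ in detail):

    #include <cstdio>
    #include <string>
    #include <vector>

    // minimal sketch of a printf-style formatter; stands in for the
    // string_format helper the loader actually uses
    template <typename... Args>
    static std::string string_format_sketch(const char * fmt, Args... args) {
        const int len = std::snprintf(nullptr, 0, fmt, args...);  // measure the result
        std::vector<char> buf(len + 1);
        std::snprintf(buf.data(), buf.size(), fmt, args...);      // format into the buffer
        return std::string(buf.data(), len);
    }

For example, string_format_sketch("sam.blk.%d.pre_ln.%s", 3, "weight") yields "sam.blk.3.pre_ln.weight".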
