
Commit 2ffafd5

correct
1 parent 8caeed5 commit 2ffafd5

1 file changed: +17 −13 lines changed

tools/mtmd/clip.cpp

Lines changed: 17 additions & 13 deletions
@@ -940,7 +940,7 @@ struct clip_graph {
         GGML_ASSERT(model.class_embedding != nullptr);
         GGML_ASSERT(model.position_embeddings != nullptr);
 
-        const int n_pos = n_patches + 1;
+        const int n_pos = n_patches + 1; // +1 for [CLS]
 
         // 2D input positions
         ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
@@ -955,17 +955,19 @@
 
         // Llama4UnfoldConvolution
         {
-            inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp, patch_size, patch_size, 0, 0, 1, 1);
-            inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
-            inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
-            cb(inp, "patch_conv", -1);
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-            cb(inp, "patch_bias", -1);
+            ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
+                                        patch_size, patch_size, 3, n_embd);
+            inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
+            inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+            inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
         }
 
         // add CLS token
         inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
 
+        // add position embeddings
+        inp = ggml_add(ctx0, inp, model.position_embeddings);
+
         // build ViT with 2D position embeddings
         auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
             return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
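
Why the replacement above is sound: a conv2d whose stride equals its kernel size visits each patch exactly once, so it decomposes into im2col (flatten each patch into a column) followed by a single matrix multiplication with the flattened kernel, which is what the im2col + mul_mat path computes. A minimal standalone sketch checking that the two paths agree; plain C++ with toy sizes, independent of ggml, and every dimension here is illustrative rather than the model's real shape:

// Sketch: verify that stride==kernel-size conv2d equals im2col + matmul.
// Toy sizes only; C/H/W/P/D are illustrative, not the model's dimensions.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int C = 3, H = 4, W = 4, P = 2; // channels, image height/width, patch size
    const int D = 5;                      // output embedding dim
    const int NP = (H / P) * (W / P);     // number of non-overlapping patches
    std::vector<float> img(C * H * W), ker(D * C * P * P);
    for (size_t i = 0; i < img.size(); ++i) img[i] = 0.01f * (float) i;
    for (size_t i = 0; i < ker.size(); ++i) ker[i] = 0.02f * (float) i - 0.3f;

    // Path A: direct conv2d with stride P -> one output per patch.
    std::vector<float> conv(D * NP, 0.0f);
    for (int d = 0; d < D; ++d)
        for (int py = 0; py < H / P; ++py)
            for (int px = 0; px < W / P; ++px)
                for (int c = 0; c < C; ++c)
                    for (int ky = 0; ky < P; ++ky)
                        for (int kx = 0; kx < P; ++kx)
                            conv[d * NP + py * (W / P) + px] +=
                                img[c * H * W + (py * P + ky) * W + (px * P + kx)] *
                                ker[((d * C + c) * P + ky) * P + kx];

    // Path B: im2col (flatten each patch to a column), then one matmul
    // against the kernel flattened to D rows of length C*P*P.
    const int K = C * P * P;
    std::vector<float> cols(NP * K);
    for (int py = 0; py < H / P; ++py)
        for (int px = 0; px < W / P; ++px)
            for (int c = 0; c < C; ++c)
                for (int ky = 0; ky < P; ++ky)
                    for (int kx = 0; kx < P; ++kx)
                        cols[(py * (W / P) + px) * K + (c * P + ky) * P + kx] =
                            img[c * H * W + (py * P + ky) * W + (px * P + kx)];
    for (int d = 0; d < D; ++d)
        for (int n = 0; n < NP; ++n) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) acc += ker[d * K + k] * cols[n * K + k];
            float diff = acc - conv[d * NP + n];
            assert(diff < 1e-5f && diff > -1e-5f); // same summation order, so exact
        }
    printf("conv2d and im2col+matmul agree on %d patches x %d dims\n", NP, D);
    return 0;
}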
@@ -988,22 +990,24 @@
         {
             const int scale_factor = model.hparams.proj_scale_factor;
             const int bsz = 1; // batch size, always 1 for now since we don't support batching
-            const int height = n_patches_y;
-            const int width = n_patches_x;
             GGML_ASSERT(scale_factor > 0);
             GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
-            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_reshape_4d(ctx0, cur,
+                n_embd * scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y,
+                bsz);
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
                 n_embd * scale_factor * scale_factor,
-                height / scale_factor,
-                width / scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y / scale_factor,
                 bsz);
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             // flatten to 2D
             cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
                 n_embd * scale_factor * scale_factor,
-                cur->ne[1] * cur->ne[2]);
+                n_patches / scale_factor / scale_factor);
         }
 
         // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
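
The reshape/permute pipeline in this hunk is a pixel-shuffle (space-to-depth) projector: each scale_factor x scale_factor block of neighboring patches is folded into the channel dimension, so the token count shrinks by scale_factor^2 while each token widens to n_embd * scale_factor^2. A standalone sketch of that regrouping; plain C++ with toy sizes, and note the in-token channel order here is a direct 2D gather, which may differ from the exact ordering the ggml reshape/permute sequence produces:

// Sketch of the pixel-shuffle regrouping: every s x s block of neighboring
// patches becomes one token with s*s times the channels. Toy sizes only.
#include <cstdio>
#include <vector>

int main() {
    const int n_embd = 4, nx = 4, ny = 4, s = 2; // s plays the role of scale_factor
    const int n_patches = nx * ny;
    std::vector<float> in((size_t) n_patches * n_embd);
    for (size_t i = 0; i < in.size(); ++i) in[i] = (float) i;

    const int out_dim = n_embd * s * s;
    std::vector<float> out((size_t) (n_patches / (s * s)) * out_dim);
    for (int Y = 0; Y < ny / s; ++Y)           // block row
        for (int X = 0; X < nx / s; ++X)       // block column
            for (int dy = 0; dy < s; ++dy)     // patch row inside the block
                for (int dx = 0; dx < s; ++dx) // patch column inside the block
                    for (int e = 0; e < n_embd; ++e) {
                        const int src_patch = (Y * s + dy) * nx + (X * s + dx);
                        const int dst_tok   = Y * (nx / s) + X;
                        const int dst_ch    = (dy * s + dx) * n_embd + e;
                        out[(size_t) dst_tok * out_dim + dst_ch] =
                            in[(size_t) src_patch * n_embd + e];
                    }
    // 16 tokens x 4 dims -> 4 tokens x 16 dims: count drops by s^2, width grows by s^2.
    printf("%d tokens x %d dims -> %d tokens x %d dims\n",
           n_patches, n_embd, n_patches / (s * s), out_dim);
    return 0;
}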
