@@ -940,7 +940,7 @@ struct clip_graph {
         GGML_ASSERT(model.class_embedding != nullptr);
         GGML_ASSERT(model.position_embeddings != nullptr);
 
-        const int n_pos = n_patches + 1;
+        const int n_pos = n_patches + 1; // +1 for [CLS]
 
         // 2D input positions
         ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
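To make the sizes concrete (illustrative numbers, not taken from this diff): a 336×336 input with 14×14 patches gives n_patches = (336/14)² = 576, so n_pos = 577 once the [CLS] token is counted. pos_h, together with the matching pos_w used by build_rope_2d below, holds each position's row/column index so rotary embeddings can be applied along both image axes.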
@@ -955,17 +955,19 @@ struct clip_graph {
 
         // Llama4UnfoldConvolution
         {
-            inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp, patch_size, patch_size, 0, 0, 1, 1);
-            inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
-            inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
-            cb(inp, "patch_conv", -1);
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-            cb(inp, "patch_bias", -1);
+            ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
+                                                   patch_size, patch_size, 3, n_embd);
+            inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
+            inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+            inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
         }
 
         // add CLS token
         inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
 
+        // add position embeddings
+        inp = ggml_add(ctx0, inp, model.position_embeddings);
+
         // build ViT with 2D position embeddings
         auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
             return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
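The block above mirrors HF's Llama4UnfoldConvolution: instead of ggml_conv_2d, im2col unrolls each patch into one column of patch_size*patch_size*3 values, and a single ggml_mul_mat with the flat projection weight produces the embeddings already in [n_embd, n_patches] order, so the transpose/cont of the old path is no longer needed. A minimal standalone sketch of the same shape flow, with toy dimensions and the assumption (implied by the reshape above) that the weight is stored flat as [patch_size*patch_size*3, n_embd]:

```cpp
// Shape-flow sketch of the unfold convolution (toy sizes; assumes the
// projection weight is stored flat as [patch_size*patch_size*3, n_embd]).
#include "ggml.h"
#include <stdio.h>

int main() {
    const int patch_size = 14, img = 56, n_embd = 32;
    const int n_patches  = (img / patch_size) * (img / patch_size); // 16

    struct ggml_init_params params = { 64u*1024*1024, NULL, false };
    struct ggml_context * ctx0 = ggml_init(params);

    // flat projection weight and an RGB input image
    ggml_tensor * w   = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, patch_size*patch_size*3, n_embd);
    ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, img, img, 3, 1);

    // reshape the flat weight into conv-kernel form so im2col can read its geometry
    ggml_tensor * kernel = ggml_reshape_4d(ctx0, w, patch_size, patch_size, 3, n_embd);

    // im2col: one column of patch_size*patch_size*3 pixels per patch
    // -> [patch_size*patch_size*3, 4, 4, 1]
    ggml_tensor * cur = ggml_im2col(ctx0, kernel, inp,
                                    patch_size, patch_size, 0, 0, 1, 1,
                                    true, inp->type);

    // one mat-mul projects every patch at once -> [n_embd, 4, 4, 1]
    cur = ggml_mul_mat(ctx0, w, cur);

    // flatten the patch grid -> [n_embd, n_patches], already in the layout
    // the rest of the graph expects (no transpose needed)
    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_patches);

    printf("%lld x %lld\n", (long long) cur->ne[0], (long long) cur->ne[1]); // 32 x 16

    ggml_free(ctx0);
    return 0;
}
```

This is also how ggml_conv_2d itself is built internally (im2col followed by a mat-mul), which is why the two formulations agree.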
@@ -988,22 +990,24 @@ struct clip_graph {
         {
             const int scale_factor = model.hparams.proj_scale_factor;
             const int bsz = 1; // batch size, always 1 for now since we don't support batching
-            const int height = n_patches_y;
-            const int width = n_patches_x;
             GGML_ASSERT(scale_factor > 0);
             GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
-            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_reshape_4d(ctx0, cur,
+                n_embd * scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y,
+                bsz);
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
                 n_embd * scale_factor * scale_factor,
-                height / scale_factor,
-                width / scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y / scale_factor,
                 bsz);
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             // flatten to 2D
             cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
                 n_embd * scale_factor * scale_factor,
-                cur->ne[1] * cur->ne[2]);
+                n_patches / scale_factor / scale_factor);
         }
 
         // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
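The reshape/permute sequence above is a pixel-shuffle (space-to-depth) projection: each scale_factor × scale_factor block of neighboring patch embeddings is folded into the channel dimension, shrinking the token count by scale_factor² while widening each token by the same factor. A standalone trace with toy sizes (n_embd = 8, a 4×4 patch grid, scale_factor = 2 are illustrative assumptions, not values from this diff):

```cpp
// Standalone shape trace of the pixel-shuffle projection (toy sizes).
#include "ggml.h"
#include <stdio.h>

int main() {
    const int n_embd = 8, n_patches_x = 4, n_patches_y = 4, scale_factor = 2;
    const int n_patches = n_patches_x * n_patches_y; // 16

    struct ggml_init_params params = { 16u*1024*1024, NULL, false };
    struct ggml_context * ctx0 = ggml_init(params);

    // token sequence coming out of the ViT: [n_embd, n_patches] = [8, 16]
    ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_patches);

    // fold scale_factor columns into the channel dim -> [16, 2, 4, 1]
    cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor,
                          n_patches_x / scale_factor, n_patches_y, 1);
    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // -> [16, 4, 2, 1]

    // fold scale_factor rows as well -> [32, 2, 2, 1]
    cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
                          n_embd * scale_factor * scale_factor,
                          n_patches_x / scale_factor,
                          n_patches_y / scale_factor, 1);
    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);

    // flatten back to a sequence -> [n_embd*4, n_patches/4] = [32, 4]
    cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
                          n_embd * scale_factor * scale_factor,
                          n_patches / scale_factor / scale_factor);

    printf("%lld x %lld\n", (long long) cur->ne[0], (long long) cur->ne[1]); // 32 x 4

    ggml_free(ctx0);
    return 0;
}
```

Spelling the final token count as n_patches / scale_factor / scale_factor, rather than reading cur->ne after the permutes, makes the intended output shape explicit at the call site.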