@@ -816,7 +816,9 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
     const auto & hparams = model.hparams;
 
     const int patch_size  = hparams.patch_size;
-    const int num_patches = ((img.nx / patch_size) * (img.ny / patch_size));
+    const int px          = img.nx / patch_size;
+    const int py          = img.ny / patch_size;
+    const int num_patches = px * py;
     const int num_pos     = num_patches + 1; // +1 for [CLS]
     const int hidden_size = hparams.hidden_size;
     const int n_head      = hparams.n_head;
@@ -849,10 +851,9 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
     ggml_tensor * inp = ggml_conv_2d(ctx0, patch_embd_view, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
     inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
     inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
-    inp = ggml_add(ctx0, inp, model.patch_bias);
 
     // add CLS
-    inp_raw = ggml_concat(ctx0, inp_raw, model.class_embedding, 0);
+    inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
 
     // 2D input positions
     ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos);
@@ -881,31 +882,31 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
             ggml_tensor * Q =
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
 
-            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_pos);
             Q = build_rope_2d(ctx0, Q, pos_w, pos_h, hparams.rope_theta, false);
             Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
 
             ggml_tensor * K =
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
 
-            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_pos);
             K = build_rope_2d(ctx0, K, pos_w, pos_h, hparams.rope_theta, false);
             K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
 
             ggml_tensor * V =
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
 
-            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_pos);
             V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
 
             ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
             KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
 
             ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_pos, n_head);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_pos);
         }
 
         // attention output
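// Illustration only (not part of this diff): shape flow through the attention
// block above, written in ggml's ne[] order (fastest dimension first). The
// num_patches -> num_pos change matters because the sequence now includes the
// [CLS] token appended earlier, so every reshape over the token axis must use
// num_pos = num_patches + 1.
//
//     cur                              : [hidden_size, num_pos]
//     Q, K after reshape_3d            : [d_head, n_head, num_pos]
//     Q, K after permute(0, 2, 1, 3)   : [d_head, num_pos, n_head]
//     V after permute(1, 2, 0, 3)      : [num_pos, d_head, n_head]
//     KQ  = mul_mat(K, Q), softmax     : [num_pos, num_pos, n_head]
//     KQV = mul_mat(V, KQ)             : [d_head, num_pos, n_head]
//     after permute + cont_2d          : [hidden_size, num_pos]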
@@ -922,8 +923,8 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
 
         if (ctx->use_silu) {
             cur = ggml_silu(ctx0, cur);
@@ -933,8 +934,8 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
             GGML_ABORT("llama4: Unsupported activation");
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
 
         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
@@ -950,33 +951,43 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
         embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
     }
 
-    // Llama4VisionPixelShuffleMLP
+    // based on Llama4VisionPixelShuffleMLP
+    // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
     {
         ggml_tensor * cur = embeddings;
+        const int batch_size = 1; // always 1 for now since we don't support batching
         const int scale_factor = model.hparams.proj_scale_factor;
-        const int n_embd = cur->ne[0];
-        const int seq = cur->ne[1];
-        const int bsz = 1; // batch size, always 1 for now since we don't support batching
-        const int height = std::sqrt(seq);
-        const int width = std::sqrt(seq);
         GGML_ASSERT(scale_factor != 0);
-        cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+
+        // remove CLS by doing a view
+        cur = ggml_view_3d(ctx0, cur,
+            hidden_size, num_patches, batch_size,
+            ggml_row_size(cur->type, hidden_size),
+            ggml_row_size(cur->type, hidden_size * num_patches), 0);
+
+        cur = ggml_reshape_3d(ctx0, cur,
+            hidden_size * scale_factor,
+            num_patches / scale_factor,
+            batch_size);
         cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
         cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
-            n_embd * scale_factor * scale_factor,
-            height / scale_factor,
-            width / scale_factor,
-            bsz);
+            hidden_size * scale_factor * scale_factor,
+            py / scale_factor,
+            px / scale_factor,
+            batch_size);
         cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-        cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
-            n_embd * scale_factor * scale_factor,
-            seq / (scale_factor * scale_factor),
-            bsz);
 
-        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
+        cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+        cur = ggml_gelu(ctx0, cur);
+        cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
         embeddings = cur;
     }
 
+    // based on Llama4MultiModalProjector
+    embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
 
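// Illustration only (not part of this diff): a minimal, self-contained sketch
// of the shape bookkeeping done by the pixel shuffle above. All sizes below
// are hypothetical placeholders, not the real Llama 4 hparams.
#include <cassert>

static void llama4_pixel_shuffle_shape_sketch() {
    const int px           = 24;   // hypothetical patch grid width
    const int py           = 24;   // hypothetical patch grid height
    const int hidden_size  = 1408; // hypothetical embedding width
    const int scale_factor = 2;    // hypothetical proj_scale_factor
    const int num_patches  = px * py;

    // after dropping CLS, the view is [hidden_size, num_patches]; the two
    // reshape + permute steps fold each scale_factor x scale_factor block of
    // patches into one token that is scale_factor^2 times wider
    const int out_width  = hidden_size * scale_factor * scale_factor;
    const int out_tokens = (px / scale_factor) * (py / scale_factor);

    // element count is preserved, token count shrinks by scale_factor^2
    assert(out_width * out_tokens == hidden_size * num_patches);
    assert(out_tokens == num_patches / (scale_factor * scale_factor));
}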
@@ -3135,6 +3146,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
     const auto & params = ctx->vision_model.hparams;
 
     int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+    int scale_factor = ctx->vision_model.hparams.proj_scale_factor;
 
     if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         n_patches /= 4;
@@ -3158,8 +3170,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         n_patches = x_patch * y_patch;
     } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         n_patches = 256;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
-        n_patches /= ctx->vision_model.hparams.proj_scale_factor;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        n_patches /= scale_factor;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
+        n_patches /= (scale_factor * scale_factor);
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
         int n_merge = ctx->vision_model.hparams.spatial_merge_size;
         int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
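// Illustration only (not part of this diff): with hypothetical values
// image_size = 336, patch_size = 14, proj_scale_factor = 2 (placeholders, not
// the real hparams), the two branches above give different token counts, which
// is the point of the split: IDEFICS3 keeps dividing by scale_factor, while the
// Llama 4 pixel shuffle merges a scale_factor x scale_factor block of patches
// per output token.
//
//     n_patches = (336 / 14) * (336 / 14) = 576
//     IDEFICS3  : 576 / 2       = 288 output tokens
//     LLAMA4    : 576 / (2 * 2) = 144 output tokens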
@@ -3757,8 +3771,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
-        case PROJECTOR_TYPE_LLAMA4:
             return ctx->vision_model.projection->ne[1];
+        case PROJECTOR_TYPE_LLAMA4:
+            return ctx->vision_model.mm_model_proj->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }