@@ -176,6 +176,10 @@ struct clip_hparams {
 };
 
 struct clip_layer {
+    // layernorm 1 (input norm)
+    struct ggml_tensor * ln_1_w = nullptr;
+    struct ggml_tensor * ln_1_b = nullptr;
+
     // attention
     struct ggml_tensor * k_w = nullptr;
     struct ggml_tensor * k_b = nullptr;
@@ -187,29 +191,28 @@ struct clip_layer {
     struct ggml_tensor * o_w = nullptr;
     struct ggml_tensor * o_b = nullptr;
 
-    // layernorm 1
-    struct ggml_tensor * ln_1_w = nullptr;
-    struct ggml_tensor * ln_1_b = nullptr;
+    // layernorm 2 (post-attn norm / pre-ffn norm)
+    struct ggml_tensor * ln_2_w = nullptr;
+    struct ggml_tensor * ln_2_b = nullptr;
 
     // ff
     struct ggml_tensor * ff_i_w = nullptr; // legacy naming
     struct ggml_tensor * ff_i_b = nullptr; // legacy naming
     struct ggml_tensor * ff_o_w = nullptr; // legacy naming
     struct ggml_tensor * ff_o_b = nullptr; // legacy naming
+    struct ggml_tensor * ff_g_w = nullptr; // legacy naming
+    struct ggml_tensor * ff_g_b = nullptr; // legacy naming
 
-    struct ggml_tensor * ff_up_w = nullptr;
-    struct ggml_tensor * ff_up_b = nullptr;
+    struct ggml_tensor * ff_up_w   = nullptr;
+    struct ggml_tensor * ff_up_b   = nullptr;
     struct ggml_tensor * ff_gate_w = nullptr;
     struct ggml_tensor * ff_gate_b = nullptr;
     struct ggml_tensor * ff_down_w = nullptr;
     struct ggml_tensor * ff_down_b = nullptr;
 
-    struct ggml_tensor * ff_g_w = NULL;
-    struct ggml_tensor * ff_g_b = NULL;
-
-    // layernorm 2
-    struct ggml_tensor * ln_2_w = nullptr;
-    struct ggml_tensor * ln_2_b = nullptr;
+    // post-ffn norm (output layer norm)
+    struct ggml_tensor * post_ffn_norm_w = nullptr;
+    struct ggml_tensor * post_ffn_norm_b = nullptr;
 };
 
 struct clip_vision_model {
@@ -560,9 +563,10 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
 static ggml_tensor * build_rope_2d(
     ggml_context * ctx0,
     ggml_tensor * cur,
-    ggml_tensor * pos_h,
-    ggml_tensor * pos_w,
-    const float freq_base
+    ggml_tensor * pos_a, // first half
+    ggml_tensor * pos_b, // second half
+    const float freq_base,
+    const bool interleave_freq
 ) {
     const int64_t n_dim  = cur->ne[0];
     const int64_t n_head = cur->ne[1];
@@ -576,7 +580,9 @@ static ggml_tensor * build_rope_2d(
     // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
     // then for the second half, we use freq_scale to shift the inv_freq
     // ^ why? replace (2i) with (2i+1) in the above equation
-    const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
+    const float freq_scale_odd = interleave_freq
+        ? std::pow(freq_base, (float)-2/n_dim)
+        : 1.0;
 
     // first half
     ggml_tensor * first;
@@ -589,7 +595,7 @@ static ggml_tensor * build_rope_2d(
         first = ggml_rope_ext(
             ctx0,
             first,
-            pos_h,      // positions
+            pos_a,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
             0, 0, freq_base,
@@ -609,7 +615,7 @@ static ggml_tensor * build_rope_2d(
         second = ggml_rope_ext(
             ctx0,
             second,
-            pos_w,      // positions
+            pos_b,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
             0, 0, freq_base,
@@ -687,13 +693,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
             struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
 
             Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
-            Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
+            Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta, true);
             Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
 
             struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
 
             K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
-            K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
+            K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta, true);
             K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
 
             struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
@@ -809,6 +815,174 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     return gf;
 }
 
+static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_image_f32 & img) {
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int patch_size = hparams.patch_size;
+    const int num_patches = ((img.nx / patch_size) * (img.ny / patch_size));
+    const int hidden_size = hparams.hidden_size;
+    const int n_head = hparams.n_head;
+    const int d_head = hidden_size / n_head;
+    const int n_layer = hparams.n_layer;
+    const float eps = hparams.eps;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx0_ptr(ggml_init(params));
+    auto ctx0 = ctx0_ptr.get();
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // input raw
+    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    // 2D input positions
+    struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+    ggml_set_name(pos_h, "pos_h");
+    ggml_set_input(pos_h);
+    struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+    ggml_set_name(pos_w, "pos_w");
+    ggml_set_input(pos_w);
+
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+    inp = ggml_add(ctx0, inp, model.patch_bias);
+
+    // position embeddings
+    struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
+
+    // loop over layers
+    for (int il = 0; il < n_layer; il++) {
+        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+        // layernorm1
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
+        }
+
+        // self-attention
+        {
+
+            struct ggml_tensor * Q =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+            Q = build_rope_2d(ctx0, Q, pos_w, pos_h, hparams.rope_theta, false);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+
+            struct ggml_tensor * K =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+            K = build_rope_2d(ctx0, K, pos_w, pos_h, hparams.rope_theta, false);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+
+            struct ggml_tensor * V =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+        // re-add the layer input, e.g., residual
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // layernorm2
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+        if (ctx->use_silu) {
+            cur = ggml_silu(ctx0, cur);
+        } else if (ctx->use_gelu) {
+            cur = ggml_gelu(ctx0, cur);
+        } else {
+            GGML_ABORT("llama4: Unsupported activation");
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        // norm output
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].post_ffn_norm_w), model.layers[il].post_ffn_norm_b);
+        }
+
+        embeddings = cur;
+    }
+
+    // post-layernorm
+    if (model.post_ln_w) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+    }
+
+    // Llama4VisionPixelShuffleMLP
+    {
+        ggml_tensor * cur = embeddings;
+        const int scale_factor = model.hparams.proj_scale_factor;
+        const int n_embd = cur->ne[0];
+        const int seq = cur->ne[1];
+        const int bsz = 1; // batch size, always 1 for now since we don't support batching
+        const int height = std::sqrt(seq);
+        const int width = std::sqrt(seq);
+        GGML_ASSERT(scale_factor != 0);
+        cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+        cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+            n_embd * scale_factor * scale_factor,
+            height / scale_factor,
+            width / scale_factor,
+            bsz);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+        cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
+            n_embd * scale_factor * scale_factor,
+            seq / (scale_factor * scale_factor),
+            bsz);
+
+        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        embeddings = cur;
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    return gf;
+}
+
 static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
@@ -1599,6 +1773,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = clip_image_build_graph_qwen25vl(ctx, imgs);
             } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                res = clip_image_build_graph_llama4(ctx, *imgs.entries[0]);
+            } break;
         default:
             {
                 // TODO: we should have one build_* function per model
@@ -1781,6 +1959,10 @@ struct clip_model_loader {
                 {
                     get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
                 } break;
+            case PROJECTOR_TYPE_LLAMA4:
+                {
+                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
+                } break;
             default:
                 break;
         }
@@ -1867,6 +2049,9 @@ struct clip_model_loader {
             layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
             layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
 
+            layer.post_ffn_norm_b = get_tensor(string_format(TN_FFN_POST_NORM, "v", il, "bias"), false);
+            layer.post_ffn_norm_w = get_tensor(string_format(TN_FFN_POST_NORM, "v", il, "weight"), false);
+
             // new naming
             layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
             layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
@@ -2008,6 +2193,12 @@ struct clip_model_loader {
                     vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
                     vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
+            case PROJECTOR_TYPE_LLAMA4:
+                {
+                    vision_model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
+                    vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -2796,7 +2987,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
     }
     else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
             || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
-            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
+            || ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
         clip_image_u8 resized_image;
         int sz = params.image_size;
         image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
@@ -2968,7 +3160,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         n_patches = x_patch * y_patch;
     } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         n_patches = 256;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
         n_patches /= ctx->vision_model.hparams.proj_scale_factor;
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
         int n_merge = ctx->vision_model.hparams.spatial_merge_size;
@@ -3550,6 +3742,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_LLAMA4:
             return ctx->vision_model.projection->ne[1];
         default:
             GGML_ABORT("Unknown projector type");