@@ -194,10 +194,6 @@ struct clip_hparams {
 };
 
 struct clip_layer {
-    // layernorm 1 (input norm)
-    struct ggml_tensor * ln_1_w = nullptr;
-    struct ggml_tensor * ln_1_b = nullptr;
-
     // attention
     ggml_tensor * k_w = nullptr;
     ggml_tensor * k_b = nullptr;
@@ -526,7 +522,7 @@ struct clip_graph {
         ggml_set_input(pos_w);
 
         auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
-            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta);
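+            // passing 'true' keeps the interleaved (odd/even) frequency layout that
+            // build_rope_2d always used before the interleave_freq flag was added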
+            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
         };
 
         ggml_tensor * inp = build_inp();
@@ -940,6 +936,90 @@ struct clip_graph {
         return gf;
     }
 
+    ggml_cgraph * build_llama4() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1;
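+        // the extra position belongs to the CLS token, which is concatenated
+        // after the patch embeddings below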
+
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        ggml_tensor * inp = build_inp_raw();
+
+        // Llama4UnfoldConvolution
+        {
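+            // HF's Llama4UnfoldConvolution is unfold + linear projection; since the
+            // patches do not overlap, a conv2d with stride == kernel size computes
+            // the same thing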
+            inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp, patch_size, patch_size, 0, 0, 1, 1);
+            inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
+            inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+            cb(inp, "patch_conv", -1);
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
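+        // the CLS embedding is appended after the patch tokens, so it sits at the
+        // last position and can be dropped again with a simple 2D view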
+
+        // build ViT with 2D position embeddings
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
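+            // interleave_freq = false: both halves of each head share one frequency
+            // set instead of pixtral's odd/even interleaving (see build_rope_2d)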
+            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        };
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                add_pos);
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        // based on Llama4VisionPixelShuffleMLP
+        // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz = 1; // batch size, always 1 for now since we don't support batching
+            const int height = n_patches_y;
+            const int width  = n_patches_x;
+            GGML_ASSERT(scale_factor > 0);
+            GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
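+            // net effect (s = scale_factor): [n_embd, H*W] -> [n_embd*s*s, (H/s)*(W/s)],
+            // i.e. each s x s block of patches is folded into the channel dimension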
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                cur->ne[1] * cur->ne[2]);
+        }
+
+        // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
+        {
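+            // two bias-free projections, each followed by GELU; the input width here
+            // is n_embd * scale_factor^2 coming out of the pixel shuffle above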
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // this graph is used by llava, granite and glm
     // due to having embedding_stack (used by granite), we cannot reuse build_vit
     ggml_cgraph * build_llava() {
@@ -1634,9 +1714,10 @@ struct clip_graph {
     static ggml_tensor * build_rope_2d(
         ggml_context * ctx0,
         ggml_tensor * cur,
-        ggml_tensor * pos_h,
-        ggml_tensor * pos_w,
-        const float freq_base
+        ggml_tensor * pos_a, // first half
+        ggml_tensor * pos_b, // second half
+        const float freq_base,
+        const bool interleave_freq
     ) {
         const int64_t n_dim  = cur->ne[0];
         const int64_t n_head = cur->ne[1];
@@ -1650,7 +1731,9 @@ struct clip_graph {
         // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
         // then for the second half, we use freq_scale to shift the inv_freq
         // ^ why? replace (2i) with (2i+1) in the above equation
-        const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
+        const float freq_scale_odd = interleave_freq
+            ? std::pow(freq_base, (float)-2/n_dim)
+            : 1.0;
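+        // with interleave_freq == false both halves reuse the same inv_freq values
+        // (freq_scale_odd == 1.0), which is what the llama4 graph above relies on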
 
         // first half
         ggml_tensor * first;
@@ -1663,7 +1746,7 @@ struct clip_graph {
         first = ggml_rope_ext(
             ctx0,
             first,
-            pos_h,      // positions
+            pos_a,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
             0, 0, freq_base,
@@ -1683,7 +1766,7 @@ struct clip_graph {
         second = ggml_rope_ext(
             ctx0,
             second,
-            pos_w,      // positions
+            pos_b,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
             0, 0, freq_base,