@@ -172,6 +172,7 @@ struct clip_hparams {
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
+    int32_t spatial_merge_size = 0;
 };
 
 struct clip_layer {
@@ -232,6 +233,7 @@ struct clip_vision_model {
     struct ggml_tensor * projection;
 
     // LLaVA projection
+    struct ggml_tensor * mm_input_norm_w = nullptr;
     struct ggml_tensor * mm_0_w = nullptr;
     struct ggml_tensor * mm_0_b = nullptr;
     struct ggml_tensor * mm_2_w = nullptr;
@@ -311,6 +313,7 @@ struct clip_vision_model {
 
     // pixtral
     struct ggml_tensor * token_embd_img_break = nullptr;
+    struct ggml_tensor * mm_patch_merger_w = nullptr;
 };
 
 struct clip_ctx {
@@ -637,6 +640,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     const int d_head = hidden_size / n_head;
     const int n_layer = hparams.n_layer;
     const float eps = hparams.eps;
+    const int n_merge = hparams.spatial_merge_size;
 
     struct ggml_init_params params = {
         /*.mem_size =*/ ctx->buf_compute_meta.size(),
@@ -721,7 +725,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         {
             ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
             ggml_tensor * up_proj   = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
-            gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
+            if (ctx->use_silu) {
+                gate_proj = ggml_silu(ctx0, gate_proj);
+            } else if (ctx->use_gelu) {
+                gate_proj = ggml_gelu(ctx0, gate_proj);
+            } else {
+                GGML_ABORT("Pixtral: Unsupported activation");
+            }
             cur = ggml_mul(ctx0, up_proj, gate_proj);
             cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
         }
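Aside: the hunk above generalizes the previously hard-coded SiLU into whichever gate activation the loaded model requests (ctx->use_silu / ctx->use_gelu), while keeping the gated-FFN shape down(act(gate(x)) * up(x)). A minimal scalar sketch of that structure, written against plain std::vector rather than ggml and using illustrative names only (not part of clip.cpp):

#include <cmath>
#include <stdexcept>
#include <vector>

// Gated FFN with a runtime-selected gate activation, mirroring the use_silu/use_gelu
// branch above. Only the gate/up element-wise part is shown; the down-projection
// (ggml_mul_mat with ff_down_w) is omitted for brevity.
static float silu(float x) { return x / (1.0f + std::exp(-x)); }
static float gelu(float x) { return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x))); }

std::vector<float> gated_ffn(const std::vector<float> & gate_proj, const std::vector<float> & up_proj,
                             bool use_silu, bool use_gelu) {
    std::vector<float> out(gate_proj.size());
    for (size_t i = 0; i < out.size(); ++i) {
        float g;
        if (use_silu) {
            g = silu(gate_proj[i]);
        } else if (use_gelu) {
            g = gelu(gate_proj[i]);
        } else {
            throw std::runtime_error("unsupported activation"); // stand-in for GGML_ABORT
        }
        out[i] = g * up_proj[i]; // the ggml_mul(up_proj, gate_proj) step
    }
    return out;
}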
@@ -732,14 +742,42 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         embeddings = cur;
     }
 
-    // LlavaMultiModalProjector (with GELU activation)
+    // mistral small 3.1 patch merger
+    // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
+    if (model.mm_patch_merger_w) {
+        GGML_ASSERT(hparams.spatial_merge_size > 0);
+
+        ggml_tensor * cur = embeddings;
+        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
+
+        // reshape image tokens to 2D grid
+        cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
+        cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
+        cur = ggml_cont(ctx0, cur);
+
+        // torch.nn.functional.unfold is just an im2col under the hood
+        // we just need a dummy kernel to make it work
+        ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
+        cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
+
+        // project to hidden_size
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+        cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
+        embeddings = cur;
+    }
+
+    // LlavaMultiModalProjector (always using GELU activation)
     {
         embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        if (model.mm_1_b) {
+            embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        }
 
         embeddings = ggml_gelu(ctx0, embeddings);
         embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        if (model.mm_2_b) {
+            embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        }
     }
 
     // arrangement of the [IMG_BREAK] token
@@ -749,11 +787,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
         // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
 
+        const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
+        const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
+        const int p_total = p_x * p_y;
         const int n_embd_text = embeddings->ne[0];
-        const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
+        const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
 
-        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
-        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
+        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
+        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
         tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
         tok = ggml_add(ctx0, tok, model.token_embd_img_break);
         cur = ggml_concat(ctx0, cur, tok, 1);
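Aside: taken together, the two hunks above (a) RMS-normalize the patch embeddings, (b) use ggml_im2col as an unfold that concatenates each spatial_merge_size x spatial_merge_size block of neighbouring patch vectors into one longer vector, which mm_patch_merger_w projects back to hidden_size, and (c) append one [IMG_BREAK] token per merged row except the last. A plain-CPU reference sketch of the grouping and the token count, with toy sizes (hidden_size 4, a 4x4 patch grid, n_merge 2, all hypothetical) and an element order that is only illustrative:

#include <cstdio>
#include <vector>

int main() {
    const int hidden_size = 4;   // toy value; real models are much larger
    const int n_patches_x = 4;
    const int n_patches_y = 4;
    const int n_merge     = 2;   // hparams.spatial_merge_size

    // emb[y][x] is the hidden_size-dim embedding of one patch
    std::vector<std::vector<std::vector<float>>> emb(
        n_patches_y, std::vector<std::vector<float>>(n_patches_x, std::vector<float>(hidden_size, 1.0f)));

    const int p_x = n_patches_x / n_merge;
    const int p_y = n_patches_y / n_merge;

    // the unfold/im2col step: one concatenated vector per merged cell
    std::vector<std::vector<float>> merged;
    for (int by = 0; by < p_y; ++by) {
        for (int bx = 0; bx < p_x; ++bx) {
            std::vector<float> cell;
            for (int dy = 0; dy < n_merge; ++dy) {
                for (int dx = 0; dx < n_merge; ++dx) {
                    const auto & p = emb[by*n_merge + dy][bx*n_merge + dx];
                    cell.insert(cell.end(), p.begin(), p.end());
                }
            }
            merged.push_back(cell); // length hidden_size * n_merge * n_merge, input to mm_patch_merger_w
        }
    }

    const int n_tokens_output = p_x*p_y + p_y - 1; // one [IMG_BREAK] per row, except the last
    std::printf("merged cells: %zu, tokens incl. [IMG_BREAK]: %d\n", merged.size(), n_tokens_output);
    return 0;
}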
@@ -1734,6 +1775,7 @@ struct clip_model_loader {
             case PROJECTOR_TYPE_PIXTRAL:
                 {
                     hparams.rope_theta = 10000.0f;
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                 } break;
             case PROJECTOR_TYPE_QWEN25VL:
                 {
@@ -1957,11 +1999,14 @@ struct clip_model_loader {
             case PROJECTOR_TYPE_PIXTRAL:
                 {
                     vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
-                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
                     vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
-                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
                     // [IMG_BREAK] token embedding
                     vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+                    // for mistral small 3.1
+                    vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
+                    vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
@@ -2516,7 +2561,7 @@ struct llava_uhd {
 
         // no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
 
-        auto best_size = get_best_resize(original_size, slice_size, patch_size, has_slices);
+        auto best_size = get_best_resize(original_size, slice_size, patch_size, !has_slices);
         res.overview_size = best_size;
 
         if (!has_slices) {
@@ -2926,8 +2971,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
     } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
         n_patches /= ctx->vision_model.hparams.proj_scale_factor;
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
-        int n_patches_x = img->nx / params.patch_size;
-        int n_patches_y = img->ny / params.patch_size;
+        int n_merge = ctx->vision_model.hparams.spatial_merge_size;
+        int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+        int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
         n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
     }
 
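Aside: with the divisor added above, clip_n_output_tokens reports the post-merge grid. A hypothetical standalone helper and worked example (a 512x512 input at patch_size 16 with spatial_merge_size 2; values chosen only for illustration):

// Mirrors the PROJECTOR_TYPE_PIXTRAL branch above; nx/ny are image dimensions in pixels.
static int pixtral_n_output_tokens(int nx, int ny, int patch_size, int n_merge) {
    const int m  = n_merge > 0 ? n_merge : 1;
    const int px = nx / patch_size / m;
    const int py = ny / patch_size / m;
    return py*px + py - 1; // one [IMG_BREAK] per row, except the last row
}
// pixtral_n_output_tokens(512, 512, 16, 2) == 16*16 + 16 - 1 == 271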
@@ -3484,7 +3530,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
            return ctx->vision_model.mm_model_peg_0_b->ne[0];
        case PROJECTOR_TYPE_MLP:
        case PROJECTOR_TYPE_PIXTRAL:
-           return ctx->vision_model.mm_2_b->ne[0];
+           return ctx->vision_model.mm_2_w->ne[1];
        case PROJECTOR_TYPE_MLP_NORM:
            return ctx->vision_model.mm_3_b->ne[0];
        case PROJECTOR_TYPE_MINICPMV:
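Aside: switching to mm_2_w->ne[1] pairs with making mm_2_b optional earlier in this change. For a weight used as ggml_mul_mat(W, x), ne[0] is the input dimension and ne[1] the output dimension, so the projector's output embedding size stays readable even when the GGUF ships no bias tensor.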