@@ -172,6 +172,7 @@ struct clip_hparams {
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
+    int32_t spatial_merge_size = 0;
 };
 
 struct clip_layer {
@@ -232,6 +233,7 @@ struct clip_vision_model {
     struct ggml_tensor * projection;
 
     // LLaVA projection
+    struct ggml_tensor * mm_input_norm_w = nullptr;
     struct ggml_tensor * mm_0_w = nullptr;
     struct ggml_tensor * mm_0_b = nullptr;
     struct ggml_tensor * mm_2_w = nullptr;
@@ -311,6 +313,7 @@ struct clip_vision_model {
 
     // pixtral
     struct ggml_tensor * token_embd_img_break = nullptr;
+    struct ggml_tensor * mm_patch_merger_w = nullptr;
 };
 
 struct clip_ctx {
@@ -721,7 +724,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         {
             ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
             ggml_tensor * up_proj   = ggml_mul_mat(ctx0, model.layers[il].ff_up_w,   cur);
-            gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
+            if (ctx->use_silu) {
+                gate_proj = ggml_silu(ctx0, gate_proj);
+            } else if (ctx->use_gelu) {
+                gate_proj = ggml_gelu(ctx0, gate_proj);
+            } else {
+                GGML_ABORT("Pixtral: Unsupported activation");
+            }
             cur = ggml_mul(ctx0, up_proj, gate_proj);
             cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
         }
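Note on the hunk above: this feed-forward block computes the gated FFN ff_down_w * (act(ff_gate_w * x) * (ff_up_w * x)). The change only makes the activation act follow the loader's ctx->use_silu / ctx->use_gelu flags instead of hard-coding SiLU, and aborts if neither flag is set.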
@@ -732,7 +741,18 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         embeddings = cur;
     }
 
-    // LlavaMultiModalProjector (with GELU activation)
+    // mistral small 3.1 patch merger
+    // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
+    if (model.mm_patch_merger_w) {
+        GGML_ASSERT(hparams.spatial_merge_size > 0);
+
+        embeddings = ggml_mul(ctx0, ggml_rms_norm(ctx0, embeddings, eps), model.mm_input_norm_w);
+
+        // reshape image tokens to 2D grid
+        embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, n_patches_x, n_patches_y);
+        embeddings = ggml_permute(ctx0, embeddings, 1, 2, 0, 3); // [x, y, hidden_size]
+    }
+
+    // LlavaMultiModalProjector (always using GELU activation)
     {
         embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
@@ -1734,6 +1754,7 @@ struct clip_model_loader {
             case PROJECTOR_TYPE_PIXTRAL:
                 {
                     hparams.rope_theta = 10000.0f;
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                 } break;
             case PROJECTOR_TYPE_QWEN25VL:
                 {
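Note: the trailing false makes the key optional, so existing Pixtral GGUFs without it keep the default of 0 and skip the merger path. When the key is present (reportedly 2 for Mistral Small 3.1), a merge size of k shrinks the image token count by a factor of k², e.g. a 32 x 32 patch grid becomes 16 x 16 = 256 tokens.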
@@ -1962,6 +1983,9 @@ struct clip_model_loader {
                     vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                     // [IMG_BREAK] token embedding
                     vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+                    // for mistral small 3.1
+                    vision_model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM, false);
+                    vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");