@@ -3852,7 +3852,43 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         name = name.replace("language_model.", "")  # InternVL
-        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
+
+        # handle aggregated expert tensors
+        # GGUF stores dimensions reversed from PyTorch, so:
+        # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
+        # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
+        # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
+        if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
+            mapped = f"{name}.weight" if not name.endswith(".weight") else name
+            # Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
+            # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
+            # Need PyTorch: (128, 2048, 768) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
+            permuted = data_torch.permute(0, 2, 1).contiguous()
+            return [(self.map_tensor_name(mapped), permuted)]
+
+        if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
+            if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
+                raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
+            split_dim = data_torch.shape[-1] // 2
+            gate = data_torch[..., :split_dim].contiguous()
+            up = data_torch[..., split_dim:].contiguous()
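+            # illustrative: with the shapes noted below, a fused (n_expert=128, n_embd=2048, 2*n_ff_exp=1536) tensor splits into gate/up halves of (128, 2048, 768) each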
+            # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
+            # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
+            # Need PyTorch: (128, 768, 2048) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
+            base_name = name.removesuffix(".weight")
+            base = base_name.rsplit('.', 1)[0]
+            mapped_gate = f"{base}.gate_proj.weight"
+            mapped_up = f"{base}.up_proj.weight"
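+            # e.g. a hypothetical "model.layers.0.mlp.experts.gate_up_proj" yields mapped_gate "model.layers.0.mlp.experts.gate_proj.weight" and the matching up_proj name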
+            perm_gate = gate.permute(0, 2, 1).contiguous()
+            perm_up = up.permute(0, 2, 1).contiguous()
+            return [
+                (self.map_tensor_name(mapped_gate), perm_gate),
+                (self.map_tensor_name(mapped_up), perm_up),
+            ]
+
+        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
             # skip visual tensors
             return []
         if name.find("experts") != -1:
@@ -4004,6 +4040,187 @@ def set_vocab(self):
         super().set_vocab()
 
 
+@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # Compute image_size if not present
+        if "image_size" not in self.hparams_vision:
+            # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
+            num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
+            patch_size = self.hparams_vision.get("patch_size", 16)
+            # num_position_embeddings = (image_size / patch_size) ** 2
+            # So image_size = sqrt(num_position_embeddings) * patch_size
+            image_size = int(num_pos ** 0.5 * patch_size)
+            self.hparams_vision["image_size"] = image_size
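+            # e.g. with the defaults above: sqrt(2304) = 48 patches per side, so image_size = 48 * 16 = 768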
+
+        # Rename config values for compatibility
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+
+        self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
+        for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
+            self.is_deepstack_layers[idx] = True
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
+        self.gguf_writer.add_vision_use_gelu(True)
+
+        if self.hparams_vision is not None:
+            merge_size = self.hparams_vision.get("spatial_merge_size")
+            if merge_size is not None:
+                self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))
+
+        # Use text config's rms_norm_eps for vision attention layernorm eps
+        rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+        if self.is_deepstack_layers:
+            self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        assert self.hparams_vision is not None
+        # Skip text model tensors - they go in the text model file
+        if name.startswith("model.language_model.") or name.startswith("lm_head."):
+            return []
+
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.", 1)
+
+        if name.startswith("visual.deepstack_merger_list."):
+            prefix, rest = name.split(".", maxsplit=3)[2:]
+            # prefix is the layer index, convert to absolute clip layer index!
+            idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
+            target = rest
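+            # e.g. with a hypothetical deepstack_visual_indexes of [8, 16, 24], merger list entry 1 maps to clip layer 16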
+
+            tensor_type: gguf.MODEL_TENSOR
+            if target.startswith("norm."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc1."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc2."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
+                suffix = target.split(".", 1)[1]
+            else:
+                raise ValueError(f"Unexpected deepstack tensor: {name}")
+
+            new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
+            return [(new_name, data_torch)]
+
+        if name.startswith("visual.merger."):
+            suffix = name.split(".", 2)[2]
+            if suffix.startswith("linear_fc"):
+                fc_idx_str, tail = suffix.split(".", 1)
+                fc_num = int(fc_idx_str.replace("linear_fc", ""))
+                # Qwen3VL has linear_fc1 and linear_fc2
+                # Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
+                if fc_num == 1:
+                    fc_idx = 0
+                elif fc_num == 2:
+                    fc_idx = 2
+                else:
+                    raise ValueError(f"unexpected fc index {fc_num} in {name}")
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
+            elif suffix.startswith("norm."):
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
+            else:
+                raise ValueError(f"Unexpected merger tensor: {name}")
+            return [(new_name, data_torch)]
+
+        if name == "visual.patch_embed.proj.weight":
+            # split Conv3D into Conv2Ds along temporal dimension
+            c1, c2, kt, _, _ = data_torch.shape
+            del c1, c2
+            if kt != 2:
+                raise ValueError("Current implementation only supports temporal_patch_size of 2")
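+            # e.g. a hypothetical kernel of shape (out_ch, in_ch, 2, patch, patch) is split into two (out_ch, in_ch, patch, patch) 2D kernels below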
+            return [
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+            ]
+
+        if name == "visual.patch_embed.proj.bias":
+            # Include the bias - it's used by the C++ code
+            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
+
+        if name.startswith("visual."):
+            return [(self.map_tensor_name(name), data_torch)]
+
+        # Fall back to parent class for other tensors
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLForConditionalGeneration")
+class Qwen3VLTextModel(Qwen3Model):
+    model_arch = gguf.MODEL_ARCH.QWEN3VL
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
4174+ self .gguf_writer .add_rope_dimension_sections (mrope_section [:4 ])
4175+
4176+ logger .info (f"MRoPE sections: { mrope_section [:4 ]} " )
4177+
4178+ vision_config = self .hparams .get ("vision_config" , {})
4179+ deepstack_layer_num = len (vision_config .get ("deepstack_visual_indexes" , []))
4180+ self .gguf_writer .add_num_deepstack_layers (deepstack_layer_num )
4181+
4182+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
4183+ # Skip vision tensors - they go in the mmproj file
4184+ if name .startswith ("model.visual." ):
4185+ return []
4186+
4187+ return super ().modify_tensors (data_torch , name , bid )
4188+
4189+
4190+ @ModelBase .register ("Qwen3VLMoeForConditionalGeneration" )
4191+ class Qwen3VLMoeTextModel (Qwen3MoeModel ):
4192+ model_arch = gguf .MODEL_ARCH .QWEN3VLMOE
4193+
4194+ def set_gguf_parameters (self ):
4195+ super ().set_gguf_parameters ()
4196+
4197+ # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
4198+ text_config = self .hparams .get ("text_config" , {})
4199+ # rope_scaling is deprecated in V5, use rope_parameters instead
4200+ rope_scaling = text_config .get ("rope_scaling" ) or text_config .get ("rope_parameters" ) or {}
4201+
4202+ if rope_scaling .get ("mrope_section" ):
4203+ # mrope_section contains [time, height, width] dimensions
4204+ mrope_section = rope_scaling ["mrope_section" ]
4205+ # Pad to 4 dimensions [time, height, width, extra]
4206+ while len (mrope_section ) < 4 :
4207+ mrope_section .append (0 )
4208+ self .gguf_writer .add_rope_dimension_sections (mrope_section [:4 ])
4209+
4210+ logger .info (f"MRoPE sections: { mrope_section [:4 ]} " )
4211+
4212+ vision_config = self .hparams .get ("vision_config" , {})
4213+ deepstack_layer_num = len (vision_config .get ("deepstack_visual_indexes" , []))
4214+ self .gguf_writer .add_num_deepstack_layers (deepstack_layer_num )
4215+
4216+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
4217+ # Skip vision tensors - they go in the mmproj file
4218+ if name .startswith ("model.visual." ):
4219+ return []
4220+
4221+ return super ().modify_tensors (data_torch , name , bid )
4222+
4223+
40074224@ModelBase .register ("GPT2LMHeadModel" )
40084225class GPT2Model (TextModel ):
40094226 model_arch = gguf .MODEL_ARCH .GPT2