@@ -3843,7 +3843,43 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         name = name.replace("language_model.", "")  # InternVL
-        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
+
+        # handle aggregated expert tensors
+        # GGUF stores dimensions reversed from PyTorch, so:
+        # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
+        # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
+        # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
+        if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
+            mapped = f"{name}.weight" if not name.endswith(".weight") else name
+            # Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
+            # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
+            # Need PyTorch: (128, 2048, 768) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
+            permuted = data_torch.permute(0, 2, 1).contiguous()
+            return [(self.map_tensor_name(mapped), permuted)]
+
+        if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
+            if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
+                raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
+            split_dim = data_torch.shape[-1] // 2
+            gate = data_torch[..., :split_dim].contiguous()
+            up = data_torch[..., split_dim:].contiguous()
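+            # Worked example (shapes taken from the comments below): a fused
+            # gate_up_proj of (n_expert=128, n_embd=2048, 2*n_ff_exp=1536) gives
+            # split_dim = 768, so gate and up are each (128, 2048, 768)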
+            # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
+            # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
+            # Need PyTorch: (128, 768, 2048) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
+            base_name = name.removesuffix(".weight")
+            base = base_name.rsplit('.', 1)[0]
+            mapped_gate = f"{base}.gate_proj.weight"
+            mapped_up = f"{base}.up_proj.weight"
+            perm_gate = gate.permute(0, 2, 1).contiguous()
+            perm_up = up.permute(0, 2, 1).contiguous()
+            return [
+                (self.map_tensor_name(mapped_gate), perm_gate),
+                (self.map_tensor_name(mapped_up), perm_up),
+            ]
+
+        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
             # skip visual tensors
             return []
         if name.find("experts") != -1:
@@ -3991,6 +4027,201 @@ def set_vocab(self):
         super().set_vocab()


+@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # Compute image_size if not present
+        if "image_size" not in self.hparams_vision:
+            # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
+            num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
+            patch_size = self.hparams_vision.get("patch_size", 16)
+            # num_position_embeddings = (image_size / patch_size) ** 2
+            # So image_size = sqrt(num_position_embeddings) * patch_size
+            image_size = int(num_pos ** 0.5 * patch_size)
+            self.hparams_vision["image_size"] = image_size
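+            # e.g. with the defaults above: image_size = sqrt(2304) * 16 = 48 * 16 = 768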
+
+        # Rename config values for compatibility
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+
+        self.deepstack_layers: list[int] = list(self.hparams_vision.get("deepstack_visual_indexes", []))
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
+        self.gguf_writer.add_vision_use_gelu(True)
+
+        if self.hparams_vision is not None:
+            merge_size = self.hparams_vision.get("spatial_merge_size")
+            if merge_size is not None:
+                self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))
+
+        # Use text config's rms_norm_eps for vision attention layernorm eps
+        rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+        if self.deepstack_layers:
+            self.gguf_writer.add_vision_deepstack_layers(self.deepstack_layers)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip text model tensors - they go in the text model file
+        if name.startswith("model.language_model.") or name.startswith("lm_head."):
+            return []
+
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.", 1)
+
+        if name.startswith("visual.deepstack_merger_list."):
+            prefix, rest = name.split(".", maxsplit=3)[2:]
+            idx = int(prefix)
+            target = rest
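+            # e.g. name = "visual.deepstack_merger_list.0.norm.weight"
+            # -> split(".", maxsplit=3)[2:] gives ["0", "norm.weight"], so idx = 0, target = "norm.weight"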
+
+            tensor_type: gguf.MODEL_TENSOR
+            if target.startswith("norm."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc1."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc2."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
+                suffix = target.split(".", 1)[1]
+            else:
+                raise ValueError(f"Unexpected deepstack tensor: {name}")
+
+            new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
+            return [(new_name, data_torch)]
+
+        if name.startswith("visual.merger."):
+            suffix = name.split(".", 2)[2]
+            if suffix.startswith("linear_fc"):
+                fc_idx_str, tail = suffix.split(".", 1)
+                fc_num = int(fc_idx_str.replace("linear_fc", ""))
+                # Qwen3VL has linear_fc1 and linear_fc2
+                # Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
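+                # e.g. "visual.merger.linear_fc1.weight" -> V_MMPROJ index 0
+                #      "visual.merger.linear_fc2.bias"   -> V_MMPROJ index 2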
+                if fc_num == 1:
+                    fc_idx = 0
+                elif fc_num == 2:
+                    fc_idx = 2
+                else:
+                    raise ValueError(f"unexpected fc index {fc_num} in {name}")
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
+            elif suffix.startswith("norm."):
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
+            else:
+                raise ValueError(f"Unexpected merger tensor: {name}")
+            return [(new_name, data_torch)]
+
+        if name == "visual.patch_embed.proj.weight":
+            # split Conv3D into Conv2Ds along temporal dimension
+            c1, c2, kt, _, _ = data_torch.shape
+            del c1, c2
+            if kt != 2:
+                raise ValueError("Current implementation only supports temporal_patch_size of 2")
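+            # e.g. a Conv3D weight of shape (out_c, in_c, kt=2, kh, kw) becomes two
+            # Conv2D kernels of shape (out_c, in_c, kh, kw), one per temporal slice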
+            return [
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+            ]
+
+        if name == "visual.patch_embed.proj.bias":
+            # Include the bias - it's used by the C++ code
+            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
+
+        if name.startswith("visual."):
+            if ".qkv." in name:
+                if data_torch.ndim == 2:
+                    c3, _ = data_torch.shape
+                else:
+                    c3 = data_torch.shape[0]
+                if c3 % 3 != 0:
+                    raise ValueError(f"Unexpected QKV shape for {name}: {data_torch.shape}")
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
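+                # e.g. a fused qkv weight of shape (3*d, d) gives c = d and three
+                # (d, d) slices for q, k and v; 1-D biases are split the same way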
+                base = name.replace("qkv", "{placeholder}")
+                return [
+                    (self.map_tensor_name(base.format(placeholder="q")), wq),
+                    (self.map_tensor_name(base.format(placeholder="k")), wk),
+                    (self.map_tensor_name(base.format(placeholder="v")), wv),
+                ]
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        # Fall back to parent class for other tensors
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLForConditionalGeneration")
+class Qwen3VLTextModel(Qwen3Model):
+    model_arch = gguf.MODEL_ARCH.QWEN3VL
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
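+            # e.g. (hypothetical values) mrope_section = [24, 20, 20] would be
+            # padded to [24, 20, 20, 0] before being written to GGUF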
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLMoeTextModel(Qwen3MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("GPT2LMHeadModel")
 class GPT2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT2