22from typing import Dict
33
44import torch
5+ import numpy as np
56from gguf import *
67from transformers import (
78 Qwen2_5_VLForConditionalGeneration ,
89 Qwen2_5_VLProcessor ,
910 Qwen2_5_VLConfig ,
10- Qwen2VLImageProcessor
1111)
1212
1313VISION = "clip.vision"
14- MODEL_INPUT_DIR = None
1514
1615
1716def k (raw_key : str , arch : str ) -> str :
@@ -20,260 +19,157 @@ def k(raw_key: str, arch: str) -> str:
2019
def to_gguf_name(name: str) -> str:
    """Translate a Hugging Face tensor name into the GGUF naming scheme.

    The translation is a fixed sequence of substring replacements, applied
    in order (the order matters: ``attn.`` is rewritten before ``proj.`` so
    that ``attn.proj.`` ends up as ``attn_out.``).

    Args:
        name: Tensor name as found in the model's ``state_dict``, optionally
            prefixed with ``visual.`` / ``text_model.``.

    Returns:
        The renamed tensor key. The mapping is also echoed to stdout.
    """
    og = name
    name = name.replace("text_model", "t").replace("visual", "v")
    name = name.replace("blocks", "blk").replace("embeddings.", "")
    name = name.replace("attn.", "attn_")

    # Qwen2.5 gated MLP: only the first matching projection is renamed.
    gated_mlp_parts = (
        ("mlp.gate_proj", "ffn_gate"),
        ("mlp.up_proj", "ffn_up"),
        ("mlp.down_proj", "ffn_down"),
    )
    for hf_part, gguf_part in gated_mlp_parts:
        if hf_part in name:
            name = name.replace(hf_part, gguf_part)
            break
    else:
        # No gated projection in the name: fall back to the fc1/fc2 layout.
        name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up")

    name = name.replace("proj.", "out.")
    name = name.replace("norm1", "ln1").replace("norm2", "ln2")
    name = name.replace("merger.mlp", "mm")

    # For RMSNorm, which doesn't have bias: normalise "weight_g" to "weight".
    name = name.replace("weight_g", "weight")

    print(f"[to_gguf_name] {og} --> {name}")
    return name
4846
4947
def find_vision_tensors(model, dtype) -> Dict[str, np.ndarray]:
    """Collect the tensors of ``model.visual`` keyed by their GGUF names.

    Every entry of ``model.visual.state_dict()`` is converted to NumPy and
    handled according to its name:

    * ``qkv`` tensors are split into three equal chunks along axis 0 and
      stored under the ``q`` / ``k`` / ``v`` variants of the GGUF name.
    * ``merger`` tensors: the ``ln_q`` weight/bias become ``v.post_ln.*``;
      everything else is renamed by ``to_gguf_name`` (``merger.mlp.N`` ->
      ``mm.N``).
    * ``patch_embed.proj.weight`` is split along its third axis (which must
      have size 2) into ``v.patch_embd.weight`` and ``v.patch_embd.weight.1``.
    * Anything else is renamed via ``to_gguf_name`` with a ``visual.`` prefix.

    Args:
        model: Object exposing a ``visual`` attribute with a ``state_dict()``
            whose values support ``.numpy()``.
        dtype: NumPy dtype applied to tensors with two or more dimensions.
            Tensors with at most one dimension, and names ending in
            ``_norm.weight``, are always stored as ``float32``.

    Returns:
        Mapping from GGUF tensor name to NumPy array, including a dummy
        ``v.position_embd.weight`` placeholder.

    Raises:
        AssertionError: if a ``qkv`` tensor's leading dimension is not a
            multiple of 3, or the patch-embedding temporal size is not 2.
    """
    collected = {}

    for name, ten in model.visual.state_dict().items():
        ten = ten.numpy()
        if 'qkv' in name:
            # Weight (2-D) and bias (1-D) both stack q, k, v along axis 0.
            c3 = ten.shape[0]
            assert c3 % 3 == 0, f"qkv tensor shape mismatch in {name}"
            c = c3 // 3
            # Resolve the GGUF name once and derive the three keys from it.
            base_name = to_gguf_name(f"visual.{name}")
            collected[base_name.replace("qkv", "q")] = ten[:c]
            collected[base_name.replace("qkv", "k")] = ten[c: c * 2]
            collected[base_name.replace("qkv", "v")] = ten[c * 2:]
        elif 'merger' in name:
            # Accept both the plain "weight" parameter name and the
            # "weight_g" spelling so the post-merger norm is never missed.
            if name.endswith(("ln_q.weight", "ln_q.weight_g")):
                collected['v.post_ln.weight'] = ten
            elif name.endswith("ln_q.bias"):
                collected['v.post_ln.bias'] = ten
            else:
                # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
                collected[to_gguf_name(name)] = ten
        elif 'patch_embed.proj.weight' in name:
            # NOTE: split Conv3D into Conv2Ds
            _, _, kt, _, _ = ten.shape
            assert kt == 2, "Current implementation only support temporal_patch_size of 2"
            collected["v.patch_embd.weight"] = ten[:, :, 0, ...]
            collected["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
        else:
            collected[to_gguf_name(f"visual.{name}")] = ten

    # Build a fresh dict instead of rewriting the one being iterated.
    tensor_map = {
        new_name: ten.astype(
            np.float32
            if ten.ndim <= 1 or new_name.endswith("_norm.weight")
            else dtype
        )
        for new_name, ten in collected.items()
    }
    # dummy tensor, just here as a placeholder
    tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)
    return tensor_map
14091
14192
def main(args):
    """Export the vision encoder of a Qwen2.5-VL checkpoint to a GGUF file.

    Loads the model named by ``args.model_name`` (a Hub id or a local
    directory), converts its vision tensors with ``find_vision_tensors``,
    writes them together with the ``clip.*`` metadata keys to
    ``<model-name>-vision.gguf`` in the current directory, and prints the
    output file name.

    Args:
        args: Parsed CLI namespace with ``model_name`` and ``data_type``
            (``'fp32'`` or ``'fp16'``).

    Raises:
        ValueError: if ``args.data_type`` is neither ``'fp32'`` nor ``'fp16'``.
    """
    # Imported locally so `os` does not depend on what the star-import of
    # `gguf` happens to expose.
    import os

    if args.data_type == 'fp32':
        np_dtype, ftype = np.float32, 0
    elif args.data_type == 'fp16':
        np_dtype, ftype = np.float16, 1
    else:
        raise ValueError(f"Unsupported data type: {args.data_type!r}")
    # The checkpoint is always loaded in fp32; the export dtype is applied
    # per tensor by find_vision_tensors.
    dtype = torch.float32

    model_name = args.model_name
    print("model_name: ", model_name)

    # Load the model with the specific Qwen2.5 class
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name, torch_dtype=dtype, device_map="cpu"
    )
    vcfg = model.config.vision_config

    # Where the processor is loaded from: the Hub id, or the local directory.
    processor_source = model_name
    if os.path.isdir(model_name):
        # Drop every trailing separator so basename() is not empty.
        processor_source = model_name.rstrip(os.sep) or model_name
        model_name = os.path.basename(processor_source)
    fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"

    fout = GGUFWriter(path=fname_out, arch="clip")
    try:
        fout.add_description("image encoder for Qwen2.5VL")

        fout.add_file_type(ftype)
        fout.add_bool("clip.has_text_encoder", False)
        fout.add_bool("clip.has_vision_encoder", True)
        fout.add_bool("clip.has_qwen2vl_merger", True)
        fout.add_bool("clip.is_qwen2_5", True)  # Flag to identify Qwen2.5 models
        fout.add_string("clip.projector_type", "qwen2vl_merger")

        print(vcfg)
        # SiLU activation
        fout.add_bool("clip.use_silu", True)
        fout.add_bool("clip.use_gelu", False)

        tensor_map = find_vision_tensors(model, np_dtype)
        for name, data in tensor_map.items():
            fout.add_tensor(name, data)

        fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
        fout.add_uint32("clip.vision.image_size", 14 * 40)  # reasonable size divisible by (14*2)
        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
        # NOTE(review): the projection dim should be the merger's output
        # width, which Qwen2.5-VL vision configs appear to expose as
        # `out_hidden_size`; fall back to `hidden_size` when it is absent.
        # Confirm against the consumer of this key.
        fout.add_uint32(
            "clip.vision.projection_dim",
            getattr(vcfg, "out_hidden_size", vcfg.hidden_size),
        )
        fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
        fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
        fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), vcfg.intermediate_size)
        fout.add_name(model_name)

        # Load the processor using the specific Qwen2.5 processor class
        processor = Qwen2_5_VLProcessor.from_pretrained(processor_source)

        # Get the image mean and std values from the processor; coerce to
        # plain Python floats so the array element type is unambiguous.
        image_processor = processor.image_processor
        fout.add_array("clip.vision.image_mean", [float(v) for v in image_processor.image_mean])
        fout.add_array("clip.vision.image_std", [float(v) for v in image_processor.image_std])

        fout.write_header_to_file()
        fout.write_kv_data_to_file()
        fout.write_tensors_to_file()
    finally:
        # Release the writer even if an export step fails.
        fout.close()
    print("save model as: ", fname_out)
263168
264169
if __name__ == "__main__":
    # Command-line entry point: parse the model id/path and export dtype,
    # then run the conversion.
    parser = argparse.ArgumentParser(
        description="Export the Qwen2.5-VL vision encoder to a GGUF file."
    )
    parser.add_argument(
        "model_name",
        nargs='?',
        default="Qwen/Qwen2.5-VL-3B-Instruct",
        help="Hugging Face model id or path to a local model directory",
    )
    parser.add_argument(
        "--data_type",
        nargs='?',
        choices=['fp32', 'fp16'],
        default="fp32",
        help="dtype used for the exported multi-dimensional tensors",
    )
    args = parser.parse_args()
    main(args)
0 commit comments