33import logging
44import torch
55import gguf
6+ import re
7+ import os
68
79from .ops import GGMLTensor
810from .dequant import is_quantized , dequantize_tensor
911
1012IMG_ARCH_LIST = {"flux" , "sd1" , "sdxl" , "sd3" , "aura" , "hidream" , "cosmos" , "ltxv" , "hyvid" , "wan" , "lumina2" , "qwen_image" }
1113TXT_ARCH_LIST = {"t5" , "t5encoder" , "llama" , "qwen2vl" }
14+ VIS_TYPE_LIST = {"clip-vision" }
1215
1316def get_orig_shape (reader , tensor_name ):
1417 field_key = f"comfy.gguf.orig_shape.{ tensor_name } "
@@ -70,6 +73,7 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
7073 # detect and verify architecture
7174 compat = None
7275 arch_str = get_field (reader , "general.architecture" , str )
76+ type_str = get_field (reader , "general.type" , str )
7377 if arch_str in [None , "pig" ]:
7478 if is_text_model :
7579 raise ValueError (f"This text model is incompatible with llama.cpp!\n Consider using the safetensors version\n ({ path } )" )
@@ -81,7 +85,8 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
8185 except Exception as e :
8286 raise ValueError (f"This model is not currently supported - ({ e } )" )
8387 elif arch_str not in TXT_ARCH_LIST and is_text_model :
84- raise ValueError (f"Unexpected text model architecture type in GGUF file: { arch_str !r} " )
88+ if type_str not in VIS_TYPE_LIST :
89+ raise ValueError (f"Unexpected text model architecture type in GGUF file: { arch_str !r} " )
8590 elif arch_str not in IMG_ARCH_LIST and not is_text_model :
8691 raise ValueError (f"Unexpected architecture type in GGUF file: { arch_str !r} " )
8792
@@ -165,6 +170,19 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
165170 "output.weight" : "lm_head.weight" ,
166171}
167172
173+ CLIP_VISION_SD_MAP = {
174+ "mm." : "visual.merger.mlp." ,
175+ "v.post_ln." : "visual.merger.ln_q." ,
176+ "v.patch_embd" : "visual.patch_embed.proj" ,
177+ "v.blk." : "visual.blocks." ,
178+ "ffn_up" : "mlp.up_proj" ,
179+ "ffn_down" : "mlp.down_proj" ,
180+ "ffn_gate" : "mlp.gate_proj" ,
181+ "attn_out." : "attn.proj." ,
182+ "ln1." : "norm1." ,
183+ "ln2." : "norm2." ,
184+ }
185+
168186def sd_map_replace (raw_sd , key_map ):
169187 sd = {}
170188 for k ,v in raw_sd .items ():
@@ -185,6 +203,79 @@ def llama_permute(raw_sd, n_head, n_head_kv):
185203 sd [k ] = v
186204 return sd
187205
206+ def strip_quant_suffix (name ):
207+ pattern = r"[-_]?(?:ud-)?i?q[0-9]_[a-z0-9_\-]{1,8}$"
208+ match = re .search (pattern , name , re .IGNORECASE )
209+ if match :
210+ name = name [:match .start ()]
211+ return name
212+
213+ def gguf_mmproj_loader (path ):
214+ # Reverse version of Qwen2VLVisionModel.modify_tensors
215+ logging .info ("Attenpting to find mmproj file for text encoder..." )
216+
217+ # get name to match w/o quant suffix
218+ tenc_fname = os .path .basename (path )
219+ tenc = os .path .splitext (tenc_fname )[0 ].lower ()
220+ tenc = strip_quant_suffix (tenc )
221+
222+ # try and find matching mmproj
223+ target = []
224+ root = os .path .dirname (path )
225+ for fname in os .listdir (root ):
226+ name , ext = os .path .splitext (fname )
227+ if ext .lower () != ".gguf" :
228+ continue
229+ if "mmproj" not in name .lower ():
230+ continue
231+ if tenc in name .lower ():
232+ target .append (fname )
233+
234+ if len (target ) == 0 :
235+ logging .error (f"Error: Can't find mmproj file for '{ tenc_fname } ' (matching:'{ tenc } ')! Qwen-Image-Edit will be broken!" )
236+ return {}
237+ if len (target ) > 1 :
238+ logging .error (f"Ambiguous mmproj for text encoder '{ tenc_fname } ', will use first match." )
239+
240+ logging .info (f"Using mmproj '{ target [0 ]} ' for text encoder '{ tenc_fname } '." )
241+ target = os .path .join (root , target [0 ])
242+ vsd = gguf_sd_loader (target , is_text_model = True )
243+
244+ # concat 4D to 5D
245+ if "v.patch_embd.weight.1" in vsd :
246+ w1 = dequantize_tensor (vsd .pop ("v.patch_embd.weight" ), dtype = torch .float32 )
247+ w2 = dequantize_tensor (vsd .pop ("v.patch_embd.weight.1" ), dtype = torch .float32 )
248+ vsd ["v.patch_embd.weight" ] = torch .stack ([w1 , w2 ], dim = 2 )
249+
250+ # run main replacement
251+ vsd = sd_map_replace (vsd , CLIP_VISION_SD_MAP )
252+
253+ # handle split Q/K/V
254+ if "visual.blocks.0.attn_q.weight" in vsd :
255+ attns = {}
256+ # filter out attentions + group
257+ for k ,v in vsd .items ():
258+ if any (x in k for x in ["attn_q" , "attn_k" , "attn_v" ]):
259+ k_attn , k_name = k .rsplit (".attn_" , 1 )
260+ k_attn += ".attn.qkv." + k_name .split ("." )[- 1 ]
261+ if k_attn not in attns :
262+ attns [k_attn ] = {}
263+ attns [k_attn ][k_name ] = dequantize_tensor (
264+ v , dtype = (torch .bfloat16 if is_quantized (v ) else torch .float16 )
265+ )
266+
267+ # recombine
268+ for k ,v in attns .items ():
269+ suffix = k .split ("." )[- 1 ]
270+ vsd [k ] = torch .cat ([
271+ v [f"q.{ suffix } " ],
272+ v [f"k.{ suffix } " ],
273+ v [f"v.{ suffix } " ],
274+ ], dim = 0 )
275+ del attns
276+
277+ return vsd
278+
188279def gguf_tokenizer_loader (path , temb_shape ):
189280 # convert gguf tokenizer to spiece
190281 logging .info ("Attempting to recreate sentencepiece tokenizer from GGUF file metadata..." )
@@ -254,6 +345,9 @@ def gguf_clip_loader(path):
254345 sd = sd_map_replace (sd , LLAMA_SD_MAP )
255346 if arch == "llama" :
256347 sd = llama_permute (sd , 32 , 8 ) # L3
348+ if arch == "qwen2vl" :
349+ vsd = gguf_mmproj_loader (path )
350+ sd .update (vsd )
257351 else :
258352 pass
259353 return sd
0 commit comments