33import logging
44import torch
55import gguf
6+ import os
67
78from .ops import GGMLTensor
89from .dequant import is_quantized , dequantize_tensor
910
# Values of the GGUF "general.architecture" field accepted for image
# (diffusion) models.
IMG_ARCH_LIST = {
    "flux",
    "sd1",
    "sdxl",
    "sd3",
    "aura",
    "hidream",
    "cosmos",
    "ltxv",
    "hyvid",
    "wan",
    "lumina2",
    "qwen_image",
}

# Values of "general.architecture" accepted when loading a text model.
TXT_ARCH_LIST = {
    "t5",
    "t5encoder",
    "llama",
    "qwen2vl",
}

# Values of the GGUF "general.type" field that are let through the text-model
# architecture check even when the architecture is not in TXT_ARCH_LIST.
VIS_TYPE_LIST = {
    "clip-vision",
}
1214
1315def get_orig_shape (reader , tensor_name ):
1416 field_key = f"comfy.gguf.orig_shape.{ tensor_name } "
@@ -70,6 +72,7 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
7072 # detect and verify architecture
7173 compat = None
7274 arch_str = get_field (reader , "general.architecture" , str )
75+ type_str = get_field (reader , "general.type" , str )
7376 if arch_str in [None , "pig" ]:
7477 if is_text_model :
7578 raise ValueError (f"This text model is incompatible with llama.cpp!\n Consider using the safetensors version\n ({ path } )" )
@@ -81,7 +84,8 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
8184 except Exception as e :
8285 raise ValueError (f"This model is not currently supported - ({ e } )" )
8386 elif arch_str not in TXT_ARCH_LIST and is_text_model :
84- raise ValueError (f"Unexpected text model architecture type in GGUF file: { arch_str !r} " )
87+ if type_str not in VIS_TYPE_LIST :
88+ raise ValueError (f"Unexpected text model architecture type in GGUF file: { arch_str !r} " )
8589 elif arch_str not in IMG_ARCH_LIST and not is_text_model :
8690 raise ValueError (f"Unexpected architecture type in GGUF file: { arch_str !r} " )
8791
@@ -165,6 +169,19 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
165169 "output.weight" : "lm_head.weight" ,
166170}
167171
# Name fragments found in mmproj (vision) GGUF tensors mapped to the fragments
# used under the `visual.` namespace. Entry order is kept exactly as written
# because the map is handed to sd_map_replace; presumably replacements are
# applied in sequence there -- verify before reordering.
CLIP_VISION_SD_MAP = {
    # top-level modules
    "mm.": "visual.merger.mlp.",
    "v.post_ln.": "visual.merger.ln_q.",
    "v.patch_embd": "visual.patch_embed.proj",
    "v.blk.": "visual.blocks.",
    # feed-forward projections
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    # attention output projection
    "attn_out.": "attn.proj.",
    # layer norms
    "ln1.": "norm1.",
    "ln2.": "norm2.",
}
184+
168185def sd_map_replace (raw_sd , key_map ):
169186 sd = {}
170187 for k ,v in raw_sd .items ():
@@ -185,6 +202,76 @@ def llama_permute(raw_sd, n_head, n_head_kv):
185202 sd [k ] = v
186203 return sd
187204
def gguf_mmproj_loader(path):
    """Find and load the mmproj GGUF file that belongs to a text encoder.

    The text encoder's file name (minus extension and quantization tag) is
    used to search the same directory for a ``.gguf`` file whose name contains
    both ``mmproj`` and that base name. The matching file is loaded with
    ``gguf_sd_loader`` and its tensor names are rewritten with
    ``CLIP_VISION_SD_MAP``. A split patch embedding is stacked back into one
    tensor, and split Q/K/V attention tensors are additionally concatenated
    into fused ``attn.qkv`` entries (the split entries are left in place).

    Args:
        path: Path to the text encoder GGUF file.

    Returns:
        The remapped vision state dict, or an empty dict when no matching
        mmproj file exists next to ``path``.
    """
    # Reverse version of Qwen2VLVisionModel.modify_tensors
    logging.info("Attempting to find mmproj file for text encoder...")

    # get name to match w/o quant suffix
    tenc_fname = os.path.basename(path)
    tenc = os.path.splitext(tenc_fname)[0].lower()
    for qtype in gguf.GGMLQuantizationType:
        quant = qtype.name.lower()
        if quant in tenc:
            # Drop the quant tag together with the separator in front of it
            # (dash/underscore/etc). Only done when a tag was actually found,
            # so names without a quant suffix are no longer truncated.
            tenc = tenc.rsplit(quant, 1)[0].rstrip("-_. ")
            break

    # try and find matching mmproj
    # dirname() is "" for a bare file name, which os.listdir() rejects.
    root = os.path.dirname(path) or "."
    candidates = []
    # Sorted so "first match" is deterministic; os.listdir order is arbitrary.
    for fname in sorted(os.listdir(root)):
        name, ext = os.path.splitext(fname)
        name = name.lower()
        if ext.lower() != ".gguf":
            continue
        if "mmproj" not in name:
            continue
        if tenc in name:
            candidates.append(fname)

    if not candidates:
        logging.error(f"Error: Can't find mmproj file for '{tenc_fname}'! Qwen-Image-Edit will be broken!")
        return {}
    if len(candidates) > 1:
        logging.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")

    logging.info(f"Using mmproj '{candidates[0]}' for text encoder '{tenc_fname}'.")
    mmproj_path = os.path.join(root, candidates[0])
    vsd = gguf_sd_loader(mmproj_path, is_text_model=True)

    # concat 4D to 5D
    if "v.patch_embd.weight.1" in vsd:
        w1 = dequantize_tensor(vsd.pop("v.patch_embd.weight"), dtype=torch.float32)
        w2 = dequantize_tensor(vsd.pop("v.patch_embd.weight.1"), dtype=torch.float32)
        vsd["v.patch_embd.weight"] = torch.stack([w1, w2], dim=2)

    # run main replacement
    vsd = sd_map_replace(vsd, CLIP_VISION_SD_MAP)

    # handle split Q/K/V
    if "visual.blocks.0.attn_q.weight" in vsd:
        attns = {}
        # filter out attentions + group by the fused key they will end up in
        for k, v in vsd.items():
            if any(x in k for x in ["attn_q", "attn_k", "attn_v"]):
                k_attn, k_name = k.rsplit(".attn_", 1)  # k_name is e.g. "q.weight"
                k_attn += ".attn.qkv." + k_name.split(".")[-1]
                attns.setdefault(k_attn, {})[k_name] = dequantize_tensor(
                    v, dtype=(torch.bfloat16 if is_quantized(v) else torch.float16)
                )

        # recombine in Q, K, V order along the first dimension
        for k, v in attns.items():
            suffix = k.split(".")[-1]
            vsd[k] = torch.cat([
                v[f"q.{suffix}"],
                v[f"k.{suffix}"],
                v[f"v.{suffix}"],
            ], dim=0)
        del attns

    return vsd
274+
188275def gguf_tokenizer_loader (path , temb_shape ):
189276 # convert gguf tokenizer to spiece
190277 logging .info ("Attempting to recreate sentencepiece tokenizer from GGUF file metadata..." )
@@ -254,6 +341,9 @@ def gguf_clip_loader(path):
254341 sd = sd_map_replace (sd , LLAMA_SD_MAP )
255342 if arch == "llama" :
256343 sd = llama_permute (sd , 32 , 8 ) # L3
344+ if arch == "qwen2vl" :
345+ vsd = gguf_mmproj_loader (path )
346+ sd .update (vsd )
257347 else :
258348 pass
259349 return sd
0 commit comments