
Commit 5a8de41

Merge branch 'main' into stable
2 parents df44c9d + d247022

File tree: 1 file changed (+95 -1 lines)


loader.py: 95 additions, 1 deletion
@@ -3,12 +3,15 @@
 import logging
 import torch
 import gguf
+import re
+import os

 from .ops import GGMLTensor
 from .dequant import is_quantized, dequantize_tensor

 IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
 TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
+VIS_TYPE_LIST = {"clip-vision"}

 def get_orig_shape(reader, tensor_name):
     field_key = f"comfy.gguf.orig_shape.{tensor_name}"
@@ -70,6 +73,7 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False
     # detect and verify architecture
     compat = None
     arch_str = get_field(reader, "general.architecture", str)
+    type_str = get_field(reader, "general.type", str)
     if arch_str in [None, "pig"]:
         if is_text_model:
             raise ValueError(f"This text model is incompatible with llama.cpp!\nConsider using the safetensors version\n({path})")
@@ -81,7 +85,8 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False
         except Exception as e:
             raise ValueError(f"This model is not currently supported - ({e})")
     elif arch_str not in TXT_ARCH_LIST and is_text_model:
-        raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
+        if type_str not in VIS_TYPE_LIST:
+            raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
     elif arch_str not in IMG_ARCH_LIST and not is_text_model:
         raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}")

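Note: the new gate only changes behavior for text-model loads whose architecture is unrecognized. A minimal sketch of the accept/reject logic, reusing TXT_ARCH_LIST and VIS_TYPE_LIST from this file (the arch/type values below are illustrative, not read from a real GGUF):

def _would_accept(arch_str, type_str, is_text_model=True):
    # unknown text arch is now let through only when the file
    # tags itself as a vision projector via general.type
    if arch_str not in TXT_ARCH_LIST and is_text_model:
        return type_str in VIS_TYPE_LIST
    return True

assert _would_accept("t5", None)             # known text arch: accepted as before
assert _would_accept("clip", "clip-vision")  # mmproj-style file: now accepted
assert not _would_accept("clip", None)       # untyped unknown arch: still rejected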
@@ -165,6 +170,19 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False
     "output.weight": "lm_head.weight",
 }

+CLIP_VISION_SD_MAP = {
+    "mm.": "visual.merger.mlp.",
+    "v.post_ln.": "visual.merger.ln_q.",
+    "v.patch_embd": "visual.patch_embed.proj",
+    "v.blk.": "visual.blocks.",
+    "ffn_up": "mlp.up_proj",
+    "ffn_down": "mlp.down_proj",
+    "ffn_gate": "mlp.gate_proj",
+    "attn_out.": "attn.proj.",
+    "ln1.": "norm1.",
+    "ln2.": "norm2.",
+}
+
 def sd_map_replace(raw_sd, key_map):
     sd = {}
     for k,v in raw_sd.items():
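Note: CLIP_VISION_SD_MAP is fed to sd_map_replace as a substring substitution table, renaming llama.cpp-style vision tensor names to the transformers-style names ComfyUI expects. A quick sketch on one hypothetical key, assuming the plain str.replace semantics the body of sd_map_replace suggests:

key = "v.blk.0.ffn_up.weight"  # hypothetical GGUF mmproj tensor name
for old, new in CLIP_VISION_SD_MAP.items():
    key = key.replace(old, new)
print(key)  # -> visual.blocks.0.mlp.up_proj.weight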
@@ -185,6 +203,79 @@ def llama_permute(raw_sd, n_head, n_head_kv):
         sd[k] = v
     return sd

+def strip_quant_suffix(name):
+    pattern = r"[-_]?(?:ud-)?i?q[0-9]_[a-z0-9_\-]{1,8}$"
+    match = re.search(pattern, name, re.IGNORECASE)
+    if match:
+        name = name[:match.start()]
+    return name
+
+def gguf_mmproj_loader(path):
+    # Reverse version of Qwen2VLVisionModel.modify_tensors
+    logging.info("Attempting to find mmproj file for text encoder...")
+
+    # get name to match w/o quant suffix
+    tenc_fname = os.path.basename(path)
+    tenc = os.path.splitext(tenc_fname)[0].lower()
+    tenc = strip_quant_suffix(tenc)
+
+    # try to find matching mmproj
+    target = []
+    root = os.path.dirname(path)
+    for fname in os.listdir(root):
+        name, ext = os.path.splitext(fname)
+        if ext.lower() != ".gguf":
+            continue
+        if "mmproj" not in name.lower():
+            continue
+        if tenc in name.lower():
+            target.append(fname)
+
+    if len(target) == 0:
+        logging.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching: '{tenc}')! Qwen-Image-Edit will be broken!")
+        return {}
+    if len(target) > 1:
+        logging.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")
+
+    logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
+    target = os.path.join(root, target[0])
+    vsd = gguf_sd_loader(target, is_text_model=True)
+
+    # stack split 4D patch embeddings into one 5D tensor
+    if "v.patch_embd.weight.1" in vsd:
+        w1 = dequantize_tensor(vsd.pop("v.patch_embd.weight"), dtype=torch.float32)
+        w2 = dequantize_tensor(vsd.pop("v.patch_embd.weight.1"), dtype=torch.float32)
+        vsd["v.patch_embd.weight"] = torch.stack([w1, w2], dim=2)
+
+    # run main replacement
+    vsd = sd_map_replace(vsd, CLIP_VISION_SD_MAP)
+
+    # handle split Q/K/V
+    if "visual.blocks.0.attn_q.weight" in vsd:
+        attns = {}
+        # filter out attentions + group
+        for k,v in vsd.items():
+            if any(x in k for x in ["attn_q", "attn_k", "attn_v"]):
+                k_attn, k_name = k.rsplit(".attn_", 1)
+                k_attn += ".attn.qkv." + k_name.split(".")[-1]
+                if k_attn not in attns:
+                    attns[k_attn] = {}
+                attns[k_attn][k_name] = dequantize_tensor(
+                    v, dtype=(torch.bfloat16 if is_quantized(v) else torch.float16)
+                )
+
+        # recombine
+        for k,v in attns.items():
+            suffix = k.split(".")[-1]
+            vsd[k] = torch.cat([
+                v[f"q.{suffix}"],
+                v[f"k.{suffix}"],
+                v[f"v.{suffix}"],
+            ], dim=0)
+        del attns
+
+    return vsd
+
 def gguf_tokenizer_loader(path, temb_shape):
     # convert gguf tokenizer to spiece
     logging.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
@@ -254,6 +345,9 @@ def gguf_clip_loader(path):
         sd = sd_map_replace(sd, LLAMA_SD_MAP)
         if arch == "llama":
             sd = llama_permute(sd, 32, 8) # L3
+        if arch == "qwen2vl":
+            vsd = gguf_mmproj_loader(path)
+            sd.update(vsd)
     else:
         pass
     return sd
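Note: end to end, loading a Qwen2-VL text encoder GGUF now also pulls in its vision tower when a matching mmproj file sits in the same directory. A hedged usage sketch (paths and file names are hypothetical):

# models/text_encoders/Qwen2.5-VL-7B-Instruct-Q4_K_M.gguf
# models/text_encoders/Qwen2.5-VL-7B-Instruct-mmproj-F16.gguf  <- picked up by name match
sd = gguf_clip_loader("models/text_encoders/Qwen2.5-VL-7B-Instruct-Q4_K_M.gguf")
vision_keys = [k for k in sd if k.startswith("visual.")]
print(f"merged {len(vision_keys)} vision tensors")  # 0 would mean no mmproj was found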
