Commit 0ca2aeb

Qwen2-VL vision support (#317)

1 parent cf05733 commit 0ca2aeb

File tree

1 file changed: +91, -1 lines

loader.py

Lines changed: 91 additions & 1 deletion
@@ -3,12 +3,14 @@
 import logging
 import torch
 import gguf
+import os

 from .ops import GGMLTensor
 from .dequant import is_quantized, dequantize_tensor

 IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
 TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
+VIS_TYPE_LIST = {"clip-vision"}

 def get_orig_shape(reader, tensor_name):
     field_key = f"comfy.gguf.orig_shape.{tensor_name}"
@@ -70,6 +72,7 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
     # detect and verify architecture
     compat = None
     arch_str = get_field(reader, "general.architecture", str)
+    type_str = get_field(reader, "general.type", str)
     if arch_str in [None, "pig"]:
         if is_text_model:
             raise ValueError(f"This text model is incompatible with llama.cpp!\nConsider using the safetensors version\n({path})")
@@ -81,7 +84,8 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
         except Exception as e:
             raise ValueError(f"This model is not currently supported - ({e})")
     elif arch_str not in TXT_ARCH_LIST and is_text_model:
-        raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
+        if type_str not in VIS_TYPE_LIST:
+            raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
     elif arch_str not in IMG_ARCH_LIST and not is_text_model:
         raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}")

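To see what the relaxed gate admits, a minimal standalone sketch (the metadata values are hypothetical; real mmproj files may report different architecture strings):

# Hypothetical GGUF metadata for an mmproj file: the architecture string is
# unknown to TXT_ARCH_LIST, but general.type marks it as a vision model.
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
VIS_TYPE_LIST = {"clip-vision"}

arch_str, type_str, is_text_model = "clip", "clip-vision", True

if arch_str not in TXT_ARCH_LIST and is_text_model:
    if type_str not in VIS_TYPE_LIST:
        raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
# No error is raised, so the file proceeds through normal tensor loading.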
@@ -165,6 +169,19 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
     "output.weight": "lm_head.weight",
 }

+CLIP_VISION_SD_MAP = {
+    "mm.": "visual.merger.mlp.",
+    "v.post_ln.": "visual.merger.ln_q.",
+    "v.patch_embd": "visual.patch_embed.proj",
+    "v.blk.": "visual.blocks.",
+    "ffn_up": "mlp.up_proj",
+    "ffn_down": "mlp.down_proj",
+    "ffn_gate": "mlp.gate_proj",
+    "attn_out.": "attn.proj.",
+    "ln1.": "norm1.",
+    "ln2.": "norm2.",
+}
+
 def sd_map_replace(raw_sd, key_map):
     sd = {}
     for k,v in raw_sd.items():
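For reference, the renaming this map is meant to drive, assuming sd_map_replace applies each pair as a plain substring replacement (its full body is not shown in this hunk; tensor names below are illustrative):

# Assumed behavior: each (old, new) pair in the map is applied as a
# substring replacement on every tensor name.
demo_map = {
    "v.blk.": "visual.blocks.",
    "attn_out.": "attn.proj.",
    "ln1.": "norm1.",
}

def demo_replace(name, key_map):
    for old, new in key_map.items():
        name = name.replace(old, new)
    return name

print(demo_replace("v.blk.0.attn_out.weight", demo_map))  # visual.blocks.0.attn.proj.weight
print(demo_replace("v.blk.3.ln1.bias", demo_map))         # visual.blocks.3.norm1.bias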
@@ -185,6 +202,76 @@ def llama_permute(raw_sd, n_head, n_head_kv):
         sd[k] = v
     return sd

+def gguf_mmproj_loader(path):
+    # Reverse version of Qwen2VLVisionModel.modify_tensors
+    logging.info("Attempting to find mmproj file for text encoder...")
+
+    # get name to match w/o quant suffix
+    tenc_fname = os.path.basename(path)
+    tenc = os.path.splitext(tenc_fname)[0].lower()
+    for q in [x.name for x in gguf.GGMLQuantizationType]:
+        if q.lower() in tenc:
+            tenc = tenc.rsplit(q.lower(), 1)[0]
+            break
+    tenc = tenc[:-1] # drop trailing dash/underscore/etc
+
+    # try to find a matching mmproj file
+    target = []
+    root = os.path.dirname(path)
+    for fname in os.listdir(root):
+        name, ext = os.path.splitext(fname)
+        if ext.lower() != ".gguf":
+            continue
+        if "mmproj" not in name.lower():
+            continue
+        if tenc in name.lower():
+            target.append(fname)
+
+    if len(target) == 0:
+        logging.error(f"Error: Can't find mmproj file for '{tenc_fname}'! Qwen-Image-Edit will be broken!")
+        return {}
+    if len(target) > 1:
+        logging.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")
+
+    logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
+    target = os.path.join(root, target[0])
+    vsd = gguf_sd_loader(target, is_text_model=True)
+
+    # stack the two 4D patch embedding weights into one 5D tensor
+    if "v.patch_embd.weight.1" in vsd:
+        w1 = dequantize_tensor(vsd.pop("v.patch_embd.weight"), dtype=torch.float32)
+        w2 = dequantize_tensor(vsd.pop("v.patch_embd.weight.1"), dtype=torch.float32)
+        vsd["v.patch_embd.weight"] = torch.stack([w1, w2], dim=2)
+
+    # run main replacement
+    vsd = sd_map_replace(vsd, CLIP_VISION_SD_MAP)
+
+    # handle split Q/K/V
+    if "visual.blocks.0.attn_q.weight" in vsd:
+        attns = {}
+        # filter out attentions + group
+        for k,v in vsd.items():
+            if any(x in k for x in ["attn_q", "attn_k", "attn_v"]):
+                k_attn, k_name = k.rsplit(".attn_", 1)
+                k_attn += ".attn.qkv." + k_name.split(".")[-1]
+                if k_attn not in attns:
+                    attns[k_attn] = {}
+                attns[k_attn][k_name] = dequantize_tensor(
+                    v, dtype=(torch.bfloat16 if is_quantized(v) else torch.float16)
+                )
+
+        # recombine
+        for k,v in attns.items():
+            suffix = k.split(".")[-1]
+            vsd[k] = torch.cat([
+                v[f"q.{suffix}"],
+                v[f"k.{suffix}"],
+                v[f"v.{suffix}"],
+            ], dim=0)
+        del attns
+
+    return vsd
+
 def gguf_tokenizer_loader(path, temb_shape):
     # convert gguf tokenizer to spiece
     logging.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
@@ -254,6 +341,9 @@ def gguf_clip_loader(path):
         sd = sd_map_replace(sd, LLAMA_SD_MAP)
         if arch == "llama":
             sd = llama_permute(sd, 32, 8) # L3
+        if arch == "qwen2vl":
+            vsd = gguf_mmproj_loader(path)
+            sd.update(vsd)
     else:
         pass
     return sd
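Hypothetical end-to-end usage after this commit (the path is illustrative and assumes a matching mmproj-*.gguf sits in the same folder):

sd = gguf_clip_loader("models/clip/qwen2-vl-text-encoder-Q8_0.gguf")
# The returned state dict now also carries the vision tower, i.e. keys
# prefixed with "visual." merged in from the mmproj file.
print(any(key.startswith("visual.") for key in sd))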
