Skip to content

Commit 1e485a9

Browse files
committed
feat: support more Qwen LoRAs from the community.
1 parent 76c809e commit 1e485a9

File tree

4 files changed

+72
-3
lines changed

4 files changed

+72
-3
lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,10 @@
489489
"PixArtAlphaPipeline",
490490
"PixArtSigmaPAGPipeline",
491491
"PixArtSigmaPipeline",
492+
"QwenImageEditPipeline",
492493
"QwenImageImg2ImgPipeline",
493494
"QwenImageInpaintPipeline",
494495
"QwenImagePipeline",
495-
"QwenImageEditPipeline",
496496
"ReduxImageEncoder",
497497
"SanaControlNetPipeline",
498498
"SanaPAGPipeline",

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,6 +2080,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
20802080

20812081

20822082
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
2083+
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
2084+
if has_lora_unet:
2085+
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
2086+
2087+
def convert_key(key: str) -> str:
2088+
prefix = "transformer_blocks"
2089+
if "." in key:
2090+
base, suffix = key.rsplit(".", 1)
2091+
else:
2092+
base, suffix = key, ""
2093+
2094+
start = f"{prefix}_"
2095+
rest = base[len(start) :]
2096+
2097+
if "." in rest:
2098+
head, tail = rest.split(".", 1)
2099+
tail = "." + tail
2100+
else:
2101+
head, tail = rest, ""
2102+
2103+
# Protected n-grams that must keep their internal underscores
2104+
protected = {
2105+
# pairs
2106+
("to", "q"),
2107+
("to", "k"),
2108+
("to", "v"),
2109+
("to", "out"),
2110+
("add", "q"),
2111+
("add", "k"),
2112+
("add", "v"),
2113+
("txt", "mlp"),
2114+
("img", "mlp"),
2115+
("txt", "mod"),
2116+
("img", "mod"),
2117+
# triplets
2118+
("add", "q", "proj"),
2119+
("add", "k", "proj"),
2120+
("add", "v", "proj"),
2121+
("to", "add", "out"),
2122+
}
2123+
2124+
prot_by_len = {}
2125+
for ng in protected:
2126+
prot_by_len.setdefault(len(ng), set()).add(ng)
2127+
2128+
parts = head.split("_")
2129+
merged = []
2130+
i = 0
2131+
lengths_desc = sorted(prot_by_len.keys(), reverse=True)
2132+
2133+
while i < len(parts):
2134+
matched = False
2135+
for L in lengths_desc:
2136+
if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
2137+
merged.append("_".join(parts[i : i + L]))
2138+
i += L
2139+
matched = True
2140+
break
2141+
if not matched:
2142+
merged.append(parts[i])
2143+
i += 1
2144+
2145+
head_converted = ".".join(merged)
2146+
converted_base = f"{prefix}.{head_converted}{tail}"
2147+
return converted_base + (("." + suffix) if suffix else "")
2148+
2149+
state_dict = {convert_key(k): v for k, v in state_dict.items()}
2150+
20832151
converted_state_dict = {}
20842152
all_keys = list(state_dict.keys())
20852153
down_key = ".lora_down.weight"

src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6643,7 +6643,8 @@ def lora_state_dict(
66436643
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
66446644

66456645
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
6646-
if has_alphas_in_sd:
6646+
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
6647+
if has_alphas_in_sd or has_lora_unet:
66476648
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
66486649

66496650
out = (state_dict, metadata) if return_lora_metadata else state_dict

src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
else:
2525
_import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"]
2626
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
27+
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
2728
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
2829
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
29-
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
3030

3131
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
3232
try:

0 commit comments

Comments
 (0)