Commit 555b6cc

[LoRA] feat: support more Qwen LoRAs from the community. (#12170)
* feat: support more Qwen LoRAs from the community.
* revert unrelated changes.
* Revert "revert unrelated changes." This reverts commit 82dea55.
1 parent 5b53f67 commit 555b6cc
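
For context: community-trained Qwen-Image LoRAs (e.g. from Kohya-style trainers) store weights under flat, underscore-joined keys with a `lora_unet_` prefix, while diffusers expects dotted module paths. A hedged illustration of the renaming this commit enables; the `attn` module path below is an assumption for the example, not taken from the diff:

```python
# Illustrative only: example of the key renaming performed by
# _convert_non_diffusers_qwen_lora_to_diffusers (module path assumed).
src_key = "lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight"
dst_key = "transformer_blocks.0.attn.to_q.lora_down.weight"
```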

File tree

2 files changed: +70 −1 lines changed

src/diffusers/loaders/lora_conversion_utils.py
src/diffusers/loaders/lora_pipeline.py

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 68 additions & 0 deletions
@@ -2080,6 +2080,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
 
 
 def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+    if has_lora_unet:
+        state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
+        def convert_key(key: str) -> str:
+            prefix = "transformer_blocks"
+            if "." in key:
+                base, suffix = key.rsplit(".", 1)
+            else:
+                base, suffix = key, ""
+
+            start = f"{prefix}_"
+            rest = base[len(start) :]
+
+            if "." in rest:
+                head, tail = rest.split(".", 1)
+                tail = "." + tail
+            else:
+                head, tail = rest, ""
+
+            # Protected n-grams that must keep their internal underscores
+            protected = {
+                # pairs
+                ("to", "q"),
+                ("to", "k"),
+                ("to", "v"),
+                ("to", "out"),
+                ("add", "q"),
+                ("add", "k"),
+                ("add", "v"),
+                ("txt", "mlp"),
+                ("img", "mlp"),
+                ("txt", "mod"),
+                ("img", "mod"),
+                # triplets
+                ("add", "q", "proj"),
+                ("add", "k", "proj"),
+                ("add", "v", "proj"),
+                ("to", "add", "out"),
+            }
+
+            prot_by_len = {}
+            for ng in protected:
+                prot_by_len.setdefault(len(ng), set()).add(ng)
+
+            parts = head.split("_")
+            merged = []
+            i = 0
+            lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+
+            while i < len(parts):
+                matched = False
+                for L in lengths_desc:
+                    if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+                        merged.append("_".join(parts[i : i + L]))
+                        i += L
+                        matched = True
+                        break
+                if not matched:
+                    merged.append(parts[i])
+                    i += 1
+
+            head_converted = ".".join(merged)
+            converted_base = f"{prefix}.{head_converted}{tail}"
+            return converted_base + (("." + suffix) if suffix else "")
+
+        state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
     converted_state_dict = {}
     all_keys = list(state_dict.keys())
     down_key = ".lora_down.weight"
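
The heart of `convert_key` is a greedy longest-match pass: underscore-separated tokens are scanned left to right, triplet n-grams are tried before pairs, and any protected run is kept as one underscored name while every other token boundary becomes a dot. A self-contained sketch of that pass; the module paths in the demo calls are illustrative, not taken from a real checkpoint:

```python
# Standalone sketch of the greedy longest-match merge used by convert_key.
# Protected n-grams keep their underscores; other boundaries become dots.
PROTECTED = {
    ("to", "q"), ("to", "k"), ("to", "v"), ("to", "out"),
    ("add", "q"), ("add", "k"), ("add", "v"),
    ("txt", "mlp"), ("img", "mlp"), ("txt", "mod"), ("img", "mod"),
    ("add", "q", "proj"), ("add", "k", "proj"), ("add", "v", "proj"),
    ("to", "add", "out"),
}
LENGTHS = sorted({len(ng) for ng in PROTECTED}, reverse=True)  # triplets first

def merge_protected(head: str) -> str:
    parts = head.split("_")
    merged, i = [], 0
    while i < len(parts):
        for L in LENGTHS:
            if i + L <= len(parts) and tuple(parts[i : i + L]) in PROTECTED:
                merged.append("_".join(parts[i : i + L]))  # keep underscores
                i += L
                break
        else:  # no protected n-gram starts at parts[i]
            merged.append(parts[i])
            i += 1
    return ".".join(merged)

# Example keys (module paths are assumptions for illustration):
print(merge_protected("0_attn_add_q_proj"))  # -> "0.attn.add_q_proj"
print(merge_protected("3_attn_to_out_0"))    # -> "3.attn.to_out.0"
```

Trying triplets before pairs matters: without it, `add_q_proj` would greedily match the pair `("add", "q")` and split off `proj` as a separate dotted segment.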

src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 1 deletion
@@ -6643,7 +6643,8 @@ def lora_state_dict(
         state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
-        if has_alphas_in_sd:
+        has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+        if has_alphas_in_sd or has_lora_unet:
             state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
 
         out = (state_dict, metadata) if return_lora_metadata else state_dict
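
The loader-side change is a routing heuristic: `lora_state_dict` previously sent a checkpoint through the Qwen converter only when it carried `.alpha` keys, and now also does so when any key starts with `lora_unet_`. A minimal sketch of that check in isolation; the helper name is hypothetical, since the real logic lives inline in `lora_state_dict` as shown above:

```python
# Hypothetical helper isolating the routing heuristic added in this commit.
def needs_qwen_conversion(state_dict: dict) -> bool:
    has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
    return has_alphas_in_sd or has_lora_unet

# A community checkpoint with lora_unet_-prefixed keys now gets routed
# through _convert_non_diffusers_qwen_lora_to_diffusers (key illustrative):
sd = {"lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight": None}
assert needs_qwen_conversion(sd)
```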
