Commit 901da9d

Fix lora conversion function for ai-toolkit Qwen Image LoRAs (huggingface#12261)
* Fix lora conversion function for ai-toolkit Qwen Image LoRAs
* add forgotten parenthesis
* remove space new line
* update pipeline
* detect if arrow or letter
* remove whitespaces
* style
* apply suggestion
* apply suggestion
* apply suggestion

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 67ffa70 commit 901da9d

2 files changed: +42 −22 lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 40 additions & 21 deletions
@@ -2129,6 +2129,10 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
 
 
 def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
+    if has_diffusion_model:
+        state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
+
     has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
     if has_lora_unet:
         state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
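For illustration only (not part of the commit), here is a minimal Python sketch of what the new "diffusion_model." branch does to ai-toolkit-style keys; the module paths in the toy state dict below are made up for the example.

# Illustrative sketch: prefix normalization on a toy state dict.
# Key names are hypothetical; values stand in for real tensors.
state_dict = {
    "diffusion_model.transformer_blocks.0.attn.to_q.lora_A.weight": 0,
    "diffusion_model.transformer_blocks.0.attn.to_q.lora_B.weight": 0,
}
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_diffusion_model:
    state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
# Keys are now bare module paths, e.g. "transformer_blocks.0.attn.to_q.lora_A.weight",
# ready for the rest of the conversion below.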
@@ -2201,29 +2205,44 @@ def convert_key(key: str) -> str:
     all_keys = list(state_dict.keys())
     down_key = ".lora_down.weight"
     up_key = ".lora_up.weight"
+    a_key = ".lora_A.weight"
+    b_key = ".lora_B.weight"
 
-    def get_alpha_scales(down_weight, alpha_key):
-        rank = down_weight.shape[0]
-        alpha = state_dict.pop(alpha_key).item()
-        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
-        scale_down = scale
-        scale_up = 1.0
-        while scale_down * 2 < scale_up:
-            scale_down *= 2
-            scale_up /= 2
-        return scale_down, scale_up
+    has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
+    has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
 
-    for k in all_keys:
-        if k.endswith(down_key):
-            diffusers_down_key = k.replace(down_key, ".lora_A.weight")
-            diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
-            alpha_key = k.replace(down_key, ".alpha")
-
-            down_weight = state_dict.pop(k)
-            up_weight = state_dict.pop(k.replace(down_key, up_key))
-            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
-            converted_state_dict[diffusers_down_key] = down_weight * scale_down
-            converted_state_dict[diffusers_up_key] = up_weight * scale_up
+    if has_non_diffusers_lora_id:
+
+        def get_alpha_scales(down_weight, alpha_key):
+            rank = down_weight.shape[0]
+            alpha = state_dict.pop(alpha_key).item()
+            scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+            scale_down = scale
+            scale_up = 1.0
+            while scale_down * 2 < scale_up:
+                scale_down *= 2
+                scale_up /= 2
+            return scale_down, scale_up
+
+        for k in all_keys:
+            if k.endswith(down_key):
+                diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+                diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+                alpha_key = k.replace(down_key, ".alpha")
+
+                down_weight = state_dict.pop(k)
+                up_weight = state_dict.pop(k.replace(down_key, up_key))
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[diffusers_down_key] = down_weight * scale_down
+                converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+    # Already in diffusers format (lora_A/lora_B), just pop
+    elif has_diffusers_lora_id:
+        for k in all_keys:
+            if a_key in k or b_key in k:
+                converted_state_dict[k] = state_dict.pop(k)
+            elif ".alpha" in k:
+                state_dict.pop(k)
 
     if len(state_dict) > 0:
         raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
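For context, a minimal sketch of the two key layouts the updated converter now distinguishes, and of the alpha/rank rescaling that the non-diffusers branch bakes into the weights. Everything here is illustrative: the key names are hypothetical, the tensors are random, and only the scaling arithmetic mirrors get_alpha_scales.

import torch

# Hypothetical keys; real checkpoints use full module paths.
kohya_style_keys = {"blocks.0.to_q.lora_down.weight", "blocks.0.to_q.lora_up.weight", "blocks.0.to_q.alpha"}
peft_style_keys = {"blocks.0.to_q.lora_A.weight", "blocks.0.to_q.lora_B.weight"}  # taken as-is, alphas dropped

# LoRA applies alpha / rank at forward time, so the conversion folds that factor
# into the saved matrices, splitting it between down (A) and up (B) by powers of
# two to keep their magnitudes balanced.
rank, alpha = 16, 8.0
down = torch.randn(rank, 64)   # lora_down -> lora_A
up = torch.randn(64, rank)     # lora_up   -> lora_B
scale = alpha / rank           # 0.5
scale_down, scale_up = scale, 1.0
while scale_down * 2 < scale_up:
    scale_down *= 2
    scale_up /= 2
lora_A = down * scale_down
lora_B = up * scale_up
# (lora_B @ lora_A) == scale * (up @ down), so the effective update is unchanged.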

src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 1 deletion
@@ -6684,7 +6684,8 @@ def lora_state_dict(
 
         has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
         has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
-        if has_alphas_in_sd or has_lora_unet:
+        has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
+        if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
             state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
 
         out = (state_dict, metadata) if return_lora_metadata else state_dict
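With this change, an ai-toolkit-trained Qwen Image LoRA (keys prefixed with "diffusion_model.") should load like any other LoRA from the user's side. A rough usage sketch follows; the Hub repo id for the LoRA and the weight file name are placeholders, not taken from the commit.

import torch
from diffusers import DiffusionPipeline

# Placeholder identifiers for illustration; substitute your own base model and LoRA.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# lora_state_dict() now also routes "diffusion_model."-prefixed checkpoints
# through _convert_non_diffusers_qwen_lora_to_diffusers before loading.
pipe.load_lora_weights("your-username/your-qwen-image-lora", weight_name="lora.safetensors")

image = pipe("a watercolor fox in a misty forest", num_inference_steps=30).images[0]
image.save("fox.png")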
