
Commit f567f56 (parent: ff9a387)

feat: support qwen lightning lora.

File tree: 2 files changed, +41 -0 lines changed


src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 36 additions & 0 deletions
@@ -2077,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
     converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
     converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
     return converted_state_dict
+
+
+def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    converted_state_dict = {}
+    all_keys = list(state_dict.keys())
+    down_key = ".lora_down.weight"
+    up_key = ".lora_up.weight"
+
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for k in all_keys:
+        if k.endswith(down_key):
+            diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+            diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+            alpha_key = k.replace(down_key, ".alpha")
+
+            down_weight = state_dict.pop(k)
+            up_weight = state_dict.pop(k.replace(down_key, up_key))
+            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+            converted_state_dict[diffusers_down_key] = down_weight * scale_down
+            converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
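
The converter folds each layer's `alpha / rank` scale directly into the weights, and the `while` loop splits that scale geometrically between the down and up factors so neither tensor's magnitude drifts far from the original (which helps numerical stability in reduced precision). Below is a minimal sketch, not part of the commit, exercising the converter on a synthetic one-layer state dict; the key names and shapes are illustrative, and the import path assumes the function is reachable from `diffusers.loaders.lora_conversion_utils`:

    import torch

    from diffusers.loaders.lora_conversion_utils import (
        _convert_non_diffusers_qwen_lora_to_diffusers,
    )

    # Illustrative key names and shapes; a real checkpoint has many layers.
    rank, in_features, out_features = 4, 16, 8
    state_dict = {
        "blocks.0.attn.to_q.lora_down.weight": torch.randn(rank, in_features),
        "blocks.0.attn.to_q.lora_up.weight": torch.randn(out_features, rank),
        "blocks.0.attn.to_q.alpha": torch.tensor(2.0),
    }

    # The converter consumes (pops from) its input dict. Here alpha / rank
    # = 2.0 / 4 = 0.5; the while-loop exits immediately (0.5 * 2 < 1.0 is
    # false), leaving scale_down = 0.5 and scale_up = 1.0, so only the down
    # weight is rescaled.
    converted = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)

    print(sorted(converted))
    # ['transformer.blocks.0.attn.to_q.lora_A.weight',
    #  'transformer.blocks.0.attn.to_q.lora_B.weight']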

src/diffusers/loaders/lora_pipeline.py

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,7 @@
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_ltxv_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
+    _convert_non_diffusers_qwen_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
@@ -6642,6 +6643,10 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

+        has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
+        if has_alphas_in_sd:
+            state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
+
         out = (state_dict, metadata) if return_lora_metadata else state_dict
         return out
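
With this hook in place, `lora_state_dict` detects checkpoints carrying trailing `.alpha` keys (as Qwen-Image Lightning LoRAs do) and converts them transparently before loading. A hypothetical end-to-end sketch follows; `QwenImagePipeline` and `load_lora_weights` exist in diffusers, but the LoRA repo id and weight filename below are placeholders, not from this commit:

    import torch

    from diffusers import QwenImagePipeline

    pipe = QwenImagePipeline.from_pretrained(
        "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
    ).to("cuda")

    # lora_state_dict() spots the ".alpha" keys and routes the checkpoint
    # through _convert_non_diffusers_qwen_lora_to_diffusers() before the
    # adapter weights are injected into the transformer.
    pipe.load_lora_weights(
        "<org>/<qwen-lightning-lora>",   # placeholder repo id
        weight_name="lora.safetensors",  # placeholder filename
    )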
