@@ -2129,6 +2129,10 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
 
 
 def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
+    if has_diffusion_model:
+        state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
+
     has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
     if has_lora_unet:
         state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
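
Note on the first hunk: some Qwen LoRA checkpoints key their weights under a `diffusion_model.` prefix, so the converter now strips that prefix up front, mirroring the existing `lora_unet_` handling. A minimal sketch of the normalization, with made-up keys and stand-in values (not part of the diff; requires Python >= 3.9 for `str.removeprefix`):

# Toy keys; the values stand in for tensors.
state_dict = {
    "diffusion_model.transformer_blocks.0.attn.to_q.lora_down.weight": 0,
    "diffusion_model.transformer_blocks.0.attn.to_q.lora_up.weight": 1,
    "diffusion_model.transformer_blocks.0.attn.to_q.alpha": 2,
}

has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_diffusion_model:
    # Drop the prefix so downstream key renaming sees the bare module paths.
    state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}

assert all(k.startswith("transformer_blocks.") for k in state_dict)
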
@@ -2201,29 +2205,44 @@ def convert_key(key: str) -> str:
     all_keys = list(state_dict.keys())
     down_key = ".lora_down.weight"
     up_key = ".lora_up.weight"
+    a_key = ".lora_A.weight"
+    b_key = ".lora_B.weight"
 
-    def get_alpha_scales(down_weight, alpha_key):
-        rank = down_weight.shape[0]
-        alpha = state_dict.pop(alpha_key).item()
-        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
-        scale_down = scale
-        scale_up = 1.0
-        while scale_down * 2 < scale_up:
-            scale_down *= 2
-            scale_up /= 2
-        return scale_down, scale_up
+    has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
+    has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
 
-    for k in all_keys:
-        if k.endswith(down_key):
-            diffusers_down_key = k.replace(down_key, ".lora_A.weight")
-            diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
-            alpha_key = k.replace(down_key, ".alpha")
-
-            down_weight = state_dict.pop(k)
-            up_weight = state_dict.pop(k.replace(down_key, up_key))
-            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
-            converted_state_dict[diffusers_down_key] = down_weight * scale_down
-            converted_state_dict[diffusers_up_key] = up_weight * scale_up
+    if has_non_diffusers_lora_id:
+
+        def get_alpha_scales(down_weight, alpha_key):
+            rank = down_weight.shape[0]
+            alpha = state_dict.pop(alpha_key).item()
+            scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+            scale_down = scale
+            scale_up = 1.0
+            while scale_down * 2 < scale_up:
+                scale_down *= 2
+                scale_up /= 2
+            return scale_down, scale_up
+
+        for k in all_keys:
+            if k.endswith(down_key):
+                diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+                diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+                alpha_key = k.replace(down_key, ".alpha")
+
+                down_weight = state_dict.pop(k)
+                up_weight = state_dict.pop(k.replace(down_key, up_key))
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[diffusers_down_key] = down_weight * scale_down
+                converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+    # Already in diffusers format (lora_A/lora_B), just pop
+    elif has_diffusers_lora_id:
+        for k in all_keys:
+            if a_key in k or b_key in k:
+                converted_state_dict[k] = state_dict.pop(k)
+            elif ".alpha" in k:
+                state_dict.pop(k)
 
     if len(state_dict) > 0:
         raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
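
Note on the second hunk: the checkpoint is first classified by its key suffixes. `lora_down`/`lora_up` checkpoints carry an `alpha` that was applied as `alpha / rank` in the forward pass, so that factor is folded back into the weights, split between the two matrices in balanced powers of two rather than applied to one side only. `lora_A`/`lora_B` checkpoints are already in diffusers format and are passed through, with stray `alpha` entries dropped. A sketch of the alpha folding, adapted to take `alpha` directly instead of popping it from the state dict (toy shapes and values, not part of the diff):

import torch

def get_alpha_scales(down_weight, alpha):
    rank = down_weight.shape[0]
    scale = alpha / rank  # undo the 'alpha / rank' applied in the forward pass
    scale_down = scale
    scale_up = 1.0
    # Split the factor across both matrices in powers of two so that
    # neither one absorbs the whole (possibly very small) scale.
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

rank, alpha = 16, 8.0
down_weight = torch.randn(rank, 64)  # lora_down / lora_A: (rank, in_features)
up_weight = torch.randn(64, rank)    # lora_up / lora_B: (out_features, rank)
scale_down, scale_up = get_alpha_scales(down_weight, alpha)

# Invariant: scale_down * scale_up == alpha / rank, so
# (up * scale_up) @ (down * scale_down) == (alpha / rank) * (up @ down).
assert abs(scale_down * scale_up - alpha / rank) < 1e-9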