@@ -2077,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
     converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
     converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
     return converted_state_dict
+
+
+def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    converted_state_dict = {}
+    all_keys = list(state_dict.keys())
+    down_key = ".lora_down.weight"
+    up_key = ".lora_up.weight"
+
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for k in all_keys:
+        if k.endswith(down_key):
+            diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+            diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+            alpha_key = k.replace(down_key, ".alpha")
+
+            down_weight = state_dict.pop(k)
+            up_weight = state_dict.pop(k.replace(down_key, up_key))
+            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+            converted_state_dict[diffusers_down_key] = down_weight * scale_down
+            converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
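For context on the `get_alpha_scales` helper: non-diffusers checkpoints store the LoRA matrices unscaled alongside an `.alpha` tensor, and since the converted dict carries no alpha entry, the `alpha / rank` factor has to be folded into the weights themselves. The helper splits that factor between the two matrices in powers of two, keeping `scale_down * scale_up == alpha / rank` invariant while nudging the two factors toward each other so neither matrix is multiplied by a disproportionately small number (e.g. an overall factor of 1/16 is applied as 0.25 to each side). Below is a minimal sketch, not part of the PR, exercising the converter on a synthetic state dict; the module path is hypothetical and just follows the `lora_down`/`lora_up`/`alpha` naming the function matches on:

```python
import torch

rank, dim = 4, 8
key = "transformer_blocks.0.attn.to_q"  # hypothetical module path
state_dict = {
    f"{key}.lora_down.weight": torch.randn(rank, dim),
    f"{key}.lora_up.weight": torch.randn(dim, rank),
    f"{key}.alpha": torch.tensor(2.0),  # overall scale = alpha / rank = 0.5
}

converted = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
print(sorted(converted))
# ['transformer.transformer_blocks.0.attn.to_q.lora_A.weight',
#  'transformer.transformer_blocks.0.attn.to_q.lora_B.weight']
```

Note that the converter pops entries from the input dict as it goes, so the original `state_dict` is consumed; the trailing length check is what turns any unconsumed (i.e. unrecognized) keys into a hard error.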