@@ -2129,6 +2129,10 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
21292129
21302130
21312131def  _convert_non_diffusers_qwen_lora_to_diffusers (state_dict ):
2132+     has_diffusion_model  =  any (k .startswith ("diffusion_model." ) for  k  in  state_dict )
2133+     if  has_diffusion_model :
2134+         state_dict  =  {k .removeprefix ("diffusion_model." ): v  for  k , v  in  state_dict .items ()}
2135+ 
21322136    has_lora_unet  =  any (k .startswith ("lora_unet_" ) for  k  in  state_dict )
21332137    if  has_lora_unet :
21342138        state_dict  =  {k .removeprefix ("lora_unet_" ): v  for  k , v  in  state_dict .items ()}
@@ -2201,29 +2205,44 @@ def convert_key(key: str) -> str:
22012205    all_keys  =  list (state_dict .keys ())
22022206    down_key  =  ".lora_down.weight" 
22032207    up_key  =  ".lora_up.weight" 
2208+     a_key  =  ".lora_A.weight" 
2209+     b_key  =  ".lora_B.weight" 
22042210
2205-     def  get_alpha_scales (down_weight , alpha_key ):
2206-         rank  =  down_weight .shape [0 ]
2207-         alpha  =  state_dict .pop (alpha_key ).item ()
2208-         scale  =  alpha  /  rank   # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here 
2209-         scale_down  =  scale 
2210-         scale_up  =  1.0 
2211-         while  scale_down  *  2  <  scale_up :
2212-             scale_down  *=  2 
2213-             scale_up  /=  2 
2214-         return  scale_down , scale_up 
2211+     has_non_diffusers_lora_id  =  any (down_key  in  k  or  up_key  in  k  for  k  in  all_keys )
2212+     has_diffusers_lora_id  =  any (a_key  in  k  or  b_key  in  k  for  k  in  all_keys )
22152213
2216-     for  k  in  all_keys :
2217-         if  k .endswith (down_key ):
2218-             diffusers_down_key  =  k .replace (down_key , ".lora_A.weight" )
2219-             diffusers_up_key  =  k .replace (down_key , up_key ).replace (up_key , ".lora_B.weight" )
2220-             alpha_key  =  k .replace (down_key , ".alpha" )
2221- 
2222-             down_weight  =  state_dict .pop (k )
2223-             up_weight  =  state_dict .pop (k .replace (down_key , up_key ))
2224-             scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
2225-             converted_state_dict [diffusers_down_key ] =  down_weight  *  scale_down 
2226-             converted_state_dict [diffusers_up_key ] =  up_weight  *  scale_up 
2214+     if  has_non_diffusers_lora_id :
2215+ 
2216+         def  get_alpha_scales (down_weight , alpha_key ):
2217+             rank  =  down_weight .shape [0 ]
2218+             alpha  =  state_dict .pop (alpha_key ).item ()
2219+             scale  =  alpha  /  rank   # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here 
2220+             scale_down  =  scale 
2221+             scale_up  =  1.0 
2222+             while  scale_down  *  2  <  scale_up :
2223+                 scale_down  *=  2 
2224+                 scale_up  /=  2 
2225+             return  scale_down , scale_up 
2226+ 
2227+         for  k  in  all_keys :
2228+             if  k .endswith (down_key ):
2229+                 diffusers_down_key  =  k .replace (down_key , ".lora_A.weight" )
2230+                 diffusers_up_key  =  k .replace (down_key , up_key ).replace (up_key , ".lora_B.weight" )
2231+                 alpha_key  =  k .replace (down_key , ".alpha" )
2232+ 
2233+                 down_weight  =  state_dict .pop (k )
2234+                 up_weight  =  state_dict .pop (k .replace (down_key , up_key ))
2235+                 scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
2236+                 converted_state_dict [diffusers_down_key ] =  down_weight  *  scale_down 
2237+                 converted_state_dict [diffusers_up_key ] =  up_weight  *  scale_up 
2238+ 
2239+     # Already in diffusers format (lora_A/lora_B), just pop 
2240+     elif  has_diffusers_lora_id :
2241+         for  k  in  all_keys :
2242+             if  a_key  in  k  or  b_key  in  k :
2243+                 converted_state_dict [k ] =  state_dict .pop (k )
2244+             elif  ".alpha"  in  k :
2245+                 state_dict .pop (k )
22272246
22282247    if  len (state_dict ) >  0 :
22292248        raise  ValueError (f"`state_dict` should be empty at this point but has { state_dict .keys ()= }  " )
0 commit comments