@@ -1789,12 +1789,58 @@ def get_alpha_scales(down_weight, key):
17891789    return  converted_state_dict 
17901790
17911791
1792- def  _convert_non_diffusers_hidream_lora_to_diffusers (state_dict , non_diffusers_prefix = "diffusion_model" ):
1793-     if  not  all (k .startswith (non_diffusers_prefix ) for  k  in  state_dict ):
1794-         raise  ValueError ("Invalid LoRA state dict for HiDream." )
1795-     converted_state_dict  =  {k .removeprefix (f"{ non_diffusers_prefix }  ): v  for  k , v  in  state_dict .items ()}
1796-     converted_state_dict  =  {f"transformer.{ k }  : v  for  k , v  in  converted_state_dict .items ()}
1797-     return  converted_state_dict 
1792+ def  _convert_non_diffusers_hidream_lora_to_diffusers (state_dict ):
1793+     non_diffusers_prefix  =  "diffusion_model" 
1794+     is_kohya  =  all (k .startswith (f"{ non_diffusers_prefix }  ) for  k  in  state_dict )
1795+ 
1796+     def  _convert_kohya (state_dict ):
1797+         converted_state_dict  =  {k .removeprefix (f"{ non_diffusers_prefix }  ): v  for  k , v  in  state_dict .items ()}
1798+         converted_state_dict  =  {f"transformer.{ k }  : v  for  k , v  in  converted_state_dict .items ()}
1799+         return  converted_state_dict 
1800+ 
1801+     if  is_kohya :
1802+         return  _convert_kohya (state_dict )
1803+ 
1804+     else :
1805+         assert  any (k .startswith (("clip_g." , "clip_l." , "t5." , "llama." , "transformer." )) for  k  in  state_dict )
1806+         converted_state_dict  =  {}
1807+         component  =  "transformer" 
1808+         compoent_sd  =  {k : v  for  k , v  in  state_dict .items () if  k .startswith (f"{ component }  )}
1809+ 
1810+         def  _convert_omi (key , state_dict , component ):
1811+             down_key  =  f"{ key }  
1812+             down_weight  =  state_dict .pop (down_key )
1813+             lora_rank  =  down_weight .shape [0 ]
1814+ 
1815+             up_weight_key  =  f"{ key }  
1816+             up_weight  =  state_dict .pop (up_weight_key )
1817+ 
1818+             alpha_key  =  f"{ key }  
1819+             alpha  =  state_dict .pop (alpha_key )
1820+ 
1821+             # scale weight by alpha and dim 
1822+             scale  =  alpha  /  lora_rank 
1823+             # calculate scale_down and scale_up 
1824+             scale_down  =  scale 
1825+             scale_up  =  1.0 
1826+             while  scale_down  *  2  <  scale_up :
1827+                 scale_down  *=  2 
1828+                 scale_up  /=  2 
1829+             down_weight  =  down_weight  *  scale_down 
1830+             up_weight  =  up_weight  *  scale_up 
1831+ 
1832+             diffusers_down_key  =  f"{ key }  
1833+             converted_state_dict [f"{ component } { diffusers_down_key }  ] =  down_weight 
1834+             converted_state_dict [f"{ component } { diffusers_down_key .replace ('.lora_A.' , '.lora_B.' )}  ] =  up_weight 
1835+ 
1836+         all_unique_keys  =  {
1837+             k .replace (".lora_down.weight" , "" ).replace (".lora_up.weight" , "" ).replace (".alpha" , "" )
1838+             for  k  in  compoent_sd 
1839+         }
1840+         for  k  in  all_unique_keys :
1841+             _convert_omi (k , compoent_sd , component = component )
1842+ 
1843+         return  converted_state_dict 
17981844
17991845
18001846def  _convert_non_diffusers_ltxv_lora_to_diffusers (state_dict , non_diffusers_prefix = "diffusion_model" ):
0 commit comments