@@ -1608,3 +1608,64 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_musubi_wan_lora_to_diffusers(state_dict):
+    # https://github.com/kohya-ss/musubi-tuner
+    converted_state_dict = {}
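+    # musubi-tuner prefixes every Wan transformer key with "lora_unet_"; strip
+    # it so the remaining keys start at the block names.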
+    original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
+
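+    # Infer the number of transformer blocks from the key names themselves.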
+    num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
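+    # I2V LoRAs also train the extra image-conditioning projections (k_img /
+    # v_img) in cross-attention; T2V checkpoints do not have them.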
+    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+    def get_alpha_scales(down_weight, key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(key + ".alpha").item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
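+        # Redistribute the scale between the two factors in powers of two
+        # (their product stays alpha / rank) so neither weight is multiplied
+        # by an extreme value; this preserves precision in low-precision dtypes.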
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for i in range(num_blocks):
+        # Self-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # Cross-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        if is_i2v_lora:
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+                up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+                scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # FFN
+        for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
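
A minimal usage sketch (not part of the diff), assuming a musubi-tuner Wan LoRA saved as a safetensors file; the checkpoint path is hypothetical:

    from safetensors.torch import load_file

    state_dict = load_file("wan_lora_musubi.safetensors")  # hypothetical path
    converted = _convert_musubi_wan_lora_to_diffusers(state_dict)

    # Every converted key is namespaced for the diffusers Wan transformer and
    # uses the peft-style lora_A / lora_B suffixes.
    assert all(key.startswith("transformer.blocks.") for key in converted)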