@@ -1608,3 +1608,64 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_musubi_wan_lora_to_diffusers(state_dict):
+    # https://github.com/kohya-ss/musubi-tuner
+    converted_state_dict = {}
+    original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
+
+    num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
+    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+    def get_alpha_scales(down_weight, key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(key + ".alpha").item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for i in range(num_blocks):
+        # Self-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # Cross-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        if is_i2v_lora:
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+                up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+                scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # FFN
+        for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
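Note on the alpha handling in the patch above: Musubi-tuner stores a per-module alpha, and the effective LoRA scale alpha / rank is baked into the converted weights. Rather than multiplying a single factor by the full scale, get_alpha_scales splits it between lora_A and lora_B by moving factors of two from one side to the other, presumably so that (in the common alpha <= rank case) neither matrix ends up multiplied by a very small number. The snippet below is a standalone illustration of that split, not part of the patch; the helper name demo_alpha_split and the example alpha/rank values are made up for demonstration.

    import torch

    def demo_alpha_split(alpha, rank):
        # Same balancing loop as get_alpha_scales in the patch: start from
        # scale = alpha / rank and shift powers of two from scale_up to
        # scale_down until the two factors are roughly balanced.
        scale = alpha / rank
        scale_down, scale_up = scale, 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    # Illustrative LoRA pair with rank 64 and alpha 16 -> scale = 0.25, split as 0.5 * 0.5.
    down_weight = torch.randn(64, 1536)  # lora_down -> lora_A
    up_weight = torch.randn(1536, 64)    # lora_up   -> lora_B
    scale_down, scale_up = demo_alpha_split(alpha=16.0, rank=down_weight.shape[0])
    print(scale_down, scale_up)  # 0.5 0.5

    # However the scale is split, the effective LoRA delta (B @ A) is unchanged.
    delta = (up_weight * scale_up) @ (down_weight * scale_down)
    assert torch.allclose(delta, 0.25 * (up_weight @ down_weight), atol=1e-5)

With the scales folded in, a Musubi key such as lora_unet_blocks_0_self_attn_q.lora_down.weight ends up in the converted dict as transformer.blocks.0.attn1.to_q.lora_A.weight, and the matching lora_up weight as transformer.blocks.0.attn1.to_q.lora_B.weight.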