@@ -1605,10 +1605,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     if diff_keys:
         for diff_k in diff_keys:
             param = original_state_dict[diff_k]
+            threshold = 1.6e-2
+            absdiff = param.abs().max() - param.abs().min()
             all_zero = torch.all(param == 0).item()
-            if all_zero:
-                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+            all_absdiff_lower_than_threshold = absdiff < threshold
+            if all_zero or all_absdiff_lower_than_threshold:
+                logger.debug(
+                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
+                )
                 original_state_dict.pop(diff_k)
+            else:
+                print(diff_k, absdiff)

     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
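
To make the new pruning heuristic concrete, here is a minimal standalone sketch (not the diffusers implementation; the helper name and toy keys are hypothetical): a `.diff` key is dropped when its tensor is all zeros or when the spread of its absolute values stays below the hardcoded 1.6e-2 threshold, since such near-constant entries barely affect the model output.

import torch

def prune_small_diff_keys(state_dict, threshold=1.6e-2):
    # Drop ".diff" entries that are all zeros or whose absolute-value spread
    # (max |param| minus min |param|) falls below the threshold.
    for diff_k in [k for k in state_dict if k.endswith(".diff")]:
        param = state_dict[diff_k]
        absdiff = param.abs().max() - param.abs().min()
        all_zero = torch.all(param == 0).item()
        if all_zero or absdiff < threshold:
            state_dict.pop(diff_k)
    return state_dict

toy = {
    "blocks.0.norm.diff": torch.zeros(4),           # all zeros -> removed
    "blocks.0.attn.diff": torch.full((4,), 1e-3),   # spread below threshold -> removed
    "blocks.0.ffn.diff": torch.tensor([0.0, 0.5]),  # spread above threshold -> kept
}
print(sorted(prune_small_diff_keys(toy)))  # ['blocks.0.ffn.diff']
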
@@ -1655,12 +1662,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):

         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_down_key}.weight"
-            )
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_up_key}.weight"
-            )
+            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
             if f"blocks.{i}.{o}.diff_b" in original_state_dict:
                 converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
                     f"blocks.{i}.{o}.diff_b"
@@ -1669,12 +1680,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     # Remaining.
     if original_state_dict:
         if any("time_projection" in k for k in original_state_dict):
-            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_down_key}.weight"
-            )
-            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_up_key}.weight"
-            )
+            original_key = f"time_projection.1.{lora_down_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"time_projection.1.{lora_up_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
         if "time_projection.1.diff_b" in original_state_dict:
             converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
                 "time_projection.1.diff_b"
@@ -1709,6 +1724,23 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
                         original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
                     )

+        for img_ours, img_theirs in [
+            (
+                "ff.net.0.proj",
+                "img_emb.proj.1"
+            ),
+            ("ff.net.2", "img_emb.proj.3"),
+        ]:
+            original_key = f"{img_theirs}.{lora_down_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"{img_theirs}.{lora_up_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
     if len(original_state_dict) > 0:
         diff = all(".diff" in k for k in original_state_dict)
         if diff:
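
The last hunk adds an optional mapping for the image-embedder (image-to-video) LoRA weights. A condensed sketch of that mapping, assuming lora_down/lora_up naming in the source checkpoint (the real function resolves lora_down_key/lora_up_key elsewhere) and using placeholder strings for the tensors:

def convert_image_embedder_keys(original_state_dict, converted_state_dict,
                                lora_down_key="lora_down", lora_up_key="lora_up"):
    # img_emb.proj.* LoRA weights map onto condition_embedder.image_embedder.*;
    # keys that are absent (e.g. text-to-video checkpoints) are simply skipped.
    for img_ours, img_theirs in [
        ("ff.net.0.proj", "img_emb.proj.1"),
        ("ff.net.2", "img_emb.proj.3"),
    ]:
        for source_key, lora_mat in [(lora_down_key, "lora_A"), (lora_up_key, "lora_B")]:
            original_key = f"{img_theirs}.{source_key}.weight"
            converted_key = f"condition_embedder.image_embedder.{img_ours}.{lora_mat}.weight"
            if original_key in original_state_dict:
                converted_state_dict[converted_key] = original_state_dict.pop(original_key)

src = {"img_emb.proj.1.lora_down.weight": "down", "img_emb.proj.1.lora_up.weight": "up"}
dst = {}
convert_image_embedder_keys(src, dst)
print(sorted(dst))
# ['condition_embedder.image_embedder.ff.net.0.proj.lora_A.weight',
#  'condition_embedder.image_embedder.ff.net.0.proj.lora_B.weight']
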