@@ -1826,23 +1826,23 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
     lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"

-    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
-    if diff_keys:
-        for diff_k in diff_keys:
-            param = original_state_dict[diff_k]
-            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are up to 1e-3,
-            # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
-            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
-            # is okay to ignore because they do not affect the model output in a significant manner.
-            threshold = 1.6e-2
-            absdiff = param.abs().max() - param.abs().min()
-            all_zero = torch.all(param == 0).item()
-            all_absdiff_lower_than_threshold = absdiff < threshold
-            if all_zero or all_absdiff_lower_than_threshold:
-                logger.debug(
-                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
-                )
-                original_state_dict.pop(diff_k)
+    for key in list(original_state_dict.keys()):
+        if key.endswith((".diff", ".diff_b")) and "norm" in key:
+            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+            # in the future if needed, provided they are not zeroed.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+        if "time_projection" in key:
+            # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+            # our lora config adds the time proj lora layers, but we don't have the weights for them.
+            # CausVid lora has the weight keys and the bias keys.
+            # This mismatch causes problems with the automatic lora config detection. The best way out is
+            # to simply drop the time projection layer keys from the state dict. It is safe because
+            # the most important layers are QKVO projections anyway, and doing this does not seem to impact
+            # model quality in a quantifiable way.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict.")

     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
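For context on the `lora_bias` reference above, here is a minimal sketch of how PEFT's `LoraConfig.lora_bias` flag makes a `lora_B.bias` parameter available, which is what the converted `.diff_b` tensors map onto. The toy module and target name below are illustrative assumptions, not part of the diffusers conversion path:

import torch
from peft import LoraConfig, get_peft_model

# Toy stand-in module (assumption: illustrative only, not a Wan transformer layer).
base = torch.nn.Sequential(torch.nn.Linear(16, 16))

# `lora_bias=True` (available in recent PEFT releases) adds a trainable bias to lora_B,
# so converted "*.diff_b" tensors can be loaded as "*.lora_B.bias" keys.
config = LoraConfig(r=4, target_modules=["0"], lora_bias=True)
peft_model = get_peft_model(base, config)

for name, param in peft_model.named_parameters():
    if "lora_B" in name:
        print(name, tuple(param.shape))  # expect both a lora_B weight and a lora_B bias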
@@ -1918,22 +1918,6 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):

     # Remaining.
     if original_state_dict:
-        if any("time_projection" in k for k in original_state_dict):
-            original_key = f"time_projection.1.{lora_down_key}.weight"
-            converted_key = "condition_embedder.time_proj.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"time_projection.1.{lora_up_key}.weight"
-            converted_key = "condition_embedder.time_proj.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-        if "time_projection.1.diff_b" in original_state_dict:
-            converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
-                "time_projection.1.diff_b"
-            )
-
         if any("head.head" in k for k in state_dict):
             converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
                 f"head.head.{lora_down_key}.weight"