Commit 2f5cf33 ("fix")

1 parent: 8c938fb
File tree: 1 file changed (+17, -33 lines)

src/diffusers/loaders/lora_conversion_utils.py (17 additions, 33 deletions)
@@ -1826,23 +1826,23 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
     lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
 
-    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
-    if diff_keys:
-        for diff_k in diff_keys:
-            param = original_state_dict[diff_k]
-            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
-            # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
-            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
-            # is okay to ignore because they do not affect the model output in a significant manner.
-            threshold = 1.6e-2
-            absdiff = param.abs().max() - param.abs().min()
-            all_zero = torch.all(param == 0).item()
-            all_absdiff_lower_than_threshold = absdiff < threshold
-            if all_zero or all_absdiff_lower_than_threshold:
-                logger.debug(
-                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
-                )
-                original_state_dict.pop(diff_k)
+    for key in list(original_state_dict.keys()):
+        if key.endswith((".diff", ".diff_b")) and "norm" in key:
+            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+            # in future if needed and they are not zeroed.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+        if "time_projection" in key:
+            # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+            # our lora config adds the time proj lora layers, but we don't have the weights for them.
+            # CausVid lora has the weight keys and the bias keys.
+            # This mismatch causes problems with the automatic lora config detection. The best way out is
+            # to simply drop the time projection layer keys from the state dict. It is safe because
+            # the most important layers are QKVO projections anyway, and doing this does not seem to impact
+            # model quality in a quantifiable way.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict.")
 
     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
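
To make the effect of the new filtering loop concrete, here is a minimal, self-contained sketch that runs the same logic over a toy state dict. The key names and tensor shapes below are invented for illustration and are not taken from a real Wan LoRA checkpoint.

import logging

import torch

logger = logging.getLogger(__name__)

# Toy stand-in for a non-diffusers Wan LoRA state dict (illustrative keys only).
original_state_dict = {
    "blocks.0.norm3.diff": torch.zeros(16),                       # norm diff key -> dropped
    "blocks.0.norm3.diff_b": torch.zeros(16),                     # norm diff bias -> dropped
    "time_projection.1.diff_b": torch.randn(16),                  # time projection key -> dropped
    "blocks.0.self_attn.q.lora_down.weight": torch.randn(4, 16),  # kept for conversion
}

# Same logic as the added lines above: drop norm-layer diff keys and all
# time_projection keys before the rest of the conversion runs.
for key in list(original_state_dict.keys()):
    if key.endswith((".diff", ".diff_b")) and "norm" in key:
        original_state_dict.pop(key)
        logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")

    if "time_projection" in key:
        original_state_dict.pop(key)
        logger.debug(f"Removing {key} key from the state dict.")

print(sorted(original_state_dict))
# ['blocks.0.self_attn.q.lora_down.weight']

Only the attention-projection LoRA key survives; the norm diff keys and the time_projection key are popped up front instead of being handled later in the conversion.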
@@ -1918,22 +1918,6 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
     # Remaining.
     if original_state_dict:
-        if any("time_projection" in k for k in original_state_dict):
-            original_key = f"time_projection.1.{lora_down_key}.weight"
-            converted_key = "condition_embedder.time_proj.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"time_projection.1.{lora_up_key}.weight"
-            converted_key = "condition_embedder.time_proj.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            if "time_projection.1.diff_b" in original_state_dict:
-                converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
-                    "time_projection.1.diff_b"
-                )
-
         if any("head.head" in k for k in state_dict):
             converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
                 f"head.head.{lora_down_key}.weight"
