Commit 0e0467c

actually, better fix
1 parent 2f5cf33 commit 0e0467c

File tree

1 file changed: +20 -6 lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 20 additions & 6 deletions
@@ -1825,6 +1825,9 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
     lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
     lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+    has_time_projection_weight = any(
+        k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+    )
 
     for key in list(original_state_dict.keys()):
         if key.endswith((".diff", ".diff_b")) and "norm" in key:
@@ -1833,16 +1836,11 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
             original_state_dict.pop(key)
             logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
 
-        if "time_projection" in key:
+        if "time_projection" in key and not has_time_projection_weight:
             # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
             # our lora config adds the time proj lora layers, but we don't have the weights for them.
             # CausVid lora has the weight keys and the bias keys.
-            # This mismatch causes problems with the automatic lora config detection. The best way out is
-            # to simply drop the time projection layer keys from the state dict. It is safe because
-            # the most important layers are QKVO projections anyway, and doing this does not seem to impact
-            # model quality in a quantifiable way.
             original_state_dict.pop(key)
-            logger.debug(f"Removing {key} key from the state dict.")
 
         # For the `diff_b` keys, we treat them as lora_bias.
         # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
@@ -1918,6 +1916,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
     # Remaining.
     if original_state_dict:
+        if any("time_projection" in k for k in original_state_dict):
+            original_key = f"time_projection.1.{lora_down_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"time_projection.1.{lora_up_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            if "time_projection.1.diff_b" in original_state_dict:
+                converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
+                    "time_projection.1.diff_b"
+                )
+
         if any("head.head" in k for k in state_dict):
             converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
                 f"head.head.{lora_down_key}.weight"
