
Commit dc6f8f5: "update" (parent: 00b179f)

1 file changed: src/diffusers/loaders/lora_conversion_utils.py (46 additions, 14 deletions)
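The change is confined to `_convert_non_diffusers_wan_lora_to_diffusers`, the converter diffusers runs when a Wan LoRA in the original (non-diffusers) key layout is loaded. A minimal usage sketch that exercises this code path; the LoRA repo id and weight file name below are placeholders:

import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
# Non-diffusers Wan LoRA checkpoints are converted internally by
# _convert_non_diffusers_wan_lora_to_diffusers() before being injected.
pipe.load_lora_weights("some-user/some-wan-lora", weight_name="lora.safetensors")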
@@ -1605,10 +1605,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     if diff_keys:
         for diff_k in diff_keys:
             param = original_state_dict[diff_k]
+            threshold = 1.6e-2
+            absdiff = param.abs().max() - param.abs().min()
             all_zero = torch.all(param == 0).item()
-            if all_zero:
-                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+            all_absdiff_lower_than_threshold = absdiff < threshold
+            if all_zero or all_absdiff_lower_than_threshold:
+                logger.debug(
+                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
+                )
                 original_state_dict.pop(diff_k)
+            else:
+                print(diff_k, absdiff)
 
     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
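This first hunk widens the pruning rule for `.diff` keys: a tensor is now dropped not only when it is exactly zero but also when the spread of its absolute values falls below the hardcoded 1.6e-2 threshold. A self-contained sketch of the heuristic on toy tensors (the key names are made up):

import torch

threshold = 1.6e-2  # same hardcoded cutoff as in the hunk above

candidates = {
    "blocks.0.attn.diff": torch.zeros(4, 4),         # exactly zero -> dropped
    "blocks.1.attn.diff": torch.full((4, 4), 1e-3),  # constant magnitude, zero spread -> dropped
    "blocks.2.attn.diff": torch.randn(4, 4),         # meaningful variation -> kept (almost surely)
}

for key, param in list(candidates.items()):
    absdiff = param.abs().max() - param.abs().min()  # spread of absolute values
    all_zero = torch.all(param == 0).item()
    if all_zero or absdiff < threshold:
        candidates.pop(key)

print(sorted(candidates))  # ['blocks.2.attn.diff']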
@@ -1655,12 +1662,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_down_key}.weight"
-            )
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_up_key}.weight"
-            )
+            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
             if f"blocks.{i}.{o}.diff_b" in original_state_dict:
                 converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
                     f"blocks.{i}.{o}.diff_b"
@@ -1669,12 +1680,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     # Remaining.
     if original_state_dict:
         if any("time_projection" in k for k in original_state_dict):
-            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_down_key}.weight"
-            )
-            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_up_key}.weight"
-            )
+            original_key = f"time_projection.1.{lora_down_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"time_projection.1.{lora_up_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
         if "time_projection.1.diff_b" in original_state_dict:
             converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
                 "time_projection.1.diff_b"
@@ -1709,6 +1724,23 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
                     original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
                 )
 
+        for img_ours, img_theirs in [
+            (
+                "ff.net.0.proj",
+                "img_emb.proj.1"
+            ),
+            ("ff.net.2", "img_emb.proj.3"),
+        ]:
+            original_key = f"{img_theirs}.{lora_down_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"{img_theirs}.{lora_up_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
     if len(original_state_dict) > 0:
         diff = all(".diff" in k for k in original_state_dict)
         if diff:
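The last hunk adds a mapping for the image-embedder projections: `img_emb.proj.1` and `img_emb.proj.3` in the original layout become `condition_embedder.image_embedder.ff.net.0.proj` and `.ff.net.2` in diffusers. A toy demonstration of the renaming, with made-up shapes and `lora_down`/`lora_up` assumed as the non-diffusers key names:

import torch

lora_down_key, lora_up_key = "lora_down", "lora_up"  # assumed source naming
original_state_dict = {
    f"img_emb.proj.1.{lora_down_key}.weight": torch.randn(4, 16),  # toy shapes
    f"img_emb.proj.1.{lora_up_key}.weight": torch.randn(16, 4),
}
converted_state_dict = {}

for img_ours, img_theirs in [("ff.net.0.proj", "img_emb.proj.1"), ("ff.net.2", "img_emb.proj.3")]:
    for src_key, dst_key in [(lora_down_key, "lora_A"), (lora_up_key, "lora_B")]:
        original_key = f"{img_theirs}.{src_key}.weight"
        converted_key = f"condition_embedder.image_embedder.{img_ours}.{dst_key}.weight"
        if original_key in original_state_dict:
            converted_state_dict[converted_key] = original_state_dict.pop(original_key)

print(sorted(converted_state_dict))
# ['condition_embedder.image_embedder.ff.net.0.proj.lora_A.weight',
#  'condition_embedder.image_embedder.ff.net.0.proj.lora_B.weight']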
