From 4f264c93ce20c0e18ae1aa1bbc4e11173c86cd37 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 19 May 2025 15:04:17 +0530 Subject: [PATCH 1/4] start supporting kijai wan lora. --- .../loaders/lora_conversion_utils.py | 69 ++++++++++++++++--- 1 file changed, 59 insertions(+), 10 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index af7547f45198..931060c38361 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1596,48 +1596,97 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict = {} original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict}) + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k}) is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) + lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down" + lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up" + + diff_keys = [k for k in original_state_dict if k.endswith(("diff_b", "diff"))] + if diff_keys: + for diff_k in diff_keys: + param = original_state_dict[diff_k] + all_zero = torch.all(param == 0).item() + if all_zero: + logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.") + original_state_dict.pop(diff_k) for i in range(num_blocks): # Self-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.lora_A.weight" + f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight" ) converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.lora_B.weight" + f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" ) # Cross-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_A.weight" + f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" ) converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_B.weight" + f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" ) if is_i2v_lora: for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_A.weight" + f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" ) converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_B.weight" + f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" ) # FFN for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.{o}.lora_A.weight" + f"blocks.{i}.{o}.{lora_down_key}.weight" ) converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.{o}.lora_B.weight" + f"blocks.{i}.{o}.{lora_up_key}.weight" ) + # Remaining. 
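# Illustrative sketch (toy example, not part of the diff): how the lora_down/lora_up vs.
# lora_A/lora_B detection and the per-block renaming above behave on a tiny made-up state
# dict. The shapes, the single block, and the variable names are assumptions for demonstration.
import torch

toy = {
    "diffusion_model.blocks.0.self_attn.q.lora_down.weight": torch.zeros(4, 16),
    "diffusion_model.blocks.0.self_attn.q.lora_up.weight": torch.zeros(16, 4),
    "diffusion_model.blocks.0.cross_attn.k.lora_down.weight": torch.zeros(4, 16),
    "diffusion_model.blocks.0.cross_attn.k.lora_up.weight": torch.zeros(16, 4),
}
sd = {k[len("diffusion_model.") :]: v for k, v in toy.items()}

# Kijai-style Wan checkpoints use `lora_down`/`lora_up`; diffusers/PEFT ones use `lora_A`/`lora_B`.
lora_down_key = "lora_A" if any("lora_A" in k for k in sd) else "lora_down"
lora_up_key = "lora_B" if any("lora_B" in k for k in sd) else "lora_up"
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in sd if "blocks." in k})
print(lora_down_key, lora_up_key, num_blocks)  # lora_down lora_up 1

# The loops above then rename, e.g.:
#   blocks.0.self_attn.q.lora_down.weight -> blocks.0.attn1.to_q.lora_A.weight
#   blocks.0.cross_attn.k.lora_up.weight  -> blocks.0.attn2.to_k.lora_B.weight
converted = {
    "blocks.0.attn1.to_q.lora_A.weight": sd.pop(f"blocks.0.self_attn.q.{lora_down_key}.weight"),
    "blocks.0.attn1.to_q.lora_B.weight": sd.pop(f"blocks.0.self_attn.q.{lora_up_key}.weight"),
    "blocks.0.attn2.to_k.lora_A.weight": sd.pop(f"blocks.0.cross_attn.k.{lora_down_key}.weight"),
    "blocks.0.attn2.to_k.lora_B.weight": sd.pop(f"blocks.0.cross_attn.k.{lora_up_key}.weight"),
}
assert not sd  # every toy key was accounted for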
+ if original_state_dict: + if any("time_projection" in k for k in original_state_dict): + converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop( + f"time_projection.1.{lora_down_key}.weight" + ) + converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop( + f"time_projection.1.{lora_up_key}.weight" + ) + + if any("head.head" in k for k in state_dict): + converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop( + f"head.head.{lora_down_key}.weight" + ) + converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight") + + for text_time in ["text_embedding", "time_embedding"]: + if any(text_time in k for k in original_state_dict): + for b_n in [0, 2]: + diffusers_b_n = 1 if b_n == 0 else 2 + diffusers_name = ( + "condition_embedder.text_embedder" + if text_time == "text_embedding" + else "condition_embedder.time_embedder" + ) + if any(f"{text_time}.{b_n}" in k for k in original_state_dict): + converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = ( + original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight") + ) + converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = ( + original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight") + ) + if len(original_state_dict) > 0: - raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") + diff = all(".diff" in k for k in original_state_dict) + if diff: + diff_keys = {k for k in original_state_dict if k.endswith(".diff")} + print(f"{len(diff_keys)}") + if not diff: + raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") for key in list(converted_state_dict.keys()): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) From d400b256a224388af3b0173d841d165e2eabd5c5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 19 May 2025 16:27:34 +0530 Subject: [PATCH 2/4] diff_b keys. --- .../loaders/lora_conversion_utils.py | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 3a189af6466b..7619b6002bf8 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1601,7 +1601,7 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down" lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up" - diff_keys = [k for k in original_state_dict if k.endswith(("diff_b", "diff"))] + diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))] if diff_keys: for diff_k in diff_keys: param = original_state_dict[diff_k] @@ -1610,6 +1610,9 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.") original_state_dict.pop(diff_k) + # For the `diff_b` keys, we treat them as lora_Bias. 
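# Illustrative sketch (toy numbers, hypothetical names, not part of the diff): why a `diff_b`
# tensor can be stored as `lora_B.bias`. With PEFT's `lora_bias=True`, the LoRA branch computes
# lora_B(lora_A(x)); a bias on lora_B is a constant offset added to the layer output, which is
# exactly what Wan's `diff_b` tensors hold.
import torch

rank, in_features, out_features = 4, 8, 8
lora_A = torch.nn.Linear(in_features, rank, bias=False)
lora_B = torch.nn.Linear(rank, out_features, bias=True)  # bias enabled, as with `lora_bias=True`
diff_b = torch.randn(out_features)                        # the additive offset from the checkpoint
with torch.no_grad():
    lora_B.bias.copy_(diff_b)                             # `diff_b` -> `lora_B.bias`

x = torch.randn(3, in_features)
delta = lora_B(lora_A(x))  # B @ (A @ x) + diff_b
same = x @ lora_A.weight.T @ lora_B.weight.T + diff_b
print(torch.allclose(delta, same, atol=1e-5))  # True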
+ # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias + for i in range(num_blocks): # Self-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): @@ -1619,6 +1622,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" ) + if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict: + converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop( + f"blocks.{i}.self_attn.{o}.diff_b" + ) # Cross-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): @@ -1628,8 +1635,13 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" ) + if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict: + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.diff_b" + ) if is_i2v_lora: + # TODO: `diff_b` for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" @@ -1646,6 +1658,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop( f"blocks.{i}.{o}.{lora_up_key}.weight" ) + if f"blocks.{i}.{o}.diff_b" in original_state_dict: + converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop( + f"blocks.{i}.{o}.diff_b" + ) # Remaining. 
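# Illustrative sketch (hypothetical helper, not diffusers API): what the "remaining", non-block
# renames handled below amount to. Index 0 of the original text/time embedding MLP maps to
# `linear_1` and index 2 maps to `linear_2`, mirroring the loop over `b_n` in the conversion code.
_NON_BLOCK_RENAMES = {
    "time_projection.1": "condition_embedder.time_proj",
    "head.head": "proj_out",
    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
}

def rename_remaining_key(key: str, lora_down_key: str = "lora_down", lora_up_key: str = "lora_up") -> str:
    """Rename a non-block LoRA weight key; `diff_b` keys are handled separately (as lora_B.bias)."""
    for old, new in _NON_BLOCK_RENAMES.items():
        if key.startswith(old + "."):
            suffix = key[len(old) + 1 :].replace(lora_down_key, "lora_A").replace(lora_up_key, "lora_B")
            return f"{new}.{suffix}"
    return key

print(rename_remaining_key("time_projection.1.lora_down.weight"))  # condition_embedder.time_proj.lora_A.weight
print(rename_remaining_key("text_embedding.2.lora_up.weight"))     # condition_embedder.text_embedder.linear_2.lora_B.weight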
if original_state_dict: @@ -1656,12 +1672,18 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop( f"time_projection.1.{lora_up_key}.weight" ) + if "time_projection.1.diff_b" in original_state_dict: + converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop( + "time_projection.1.diff_b" + ) if any("head.head" in k for k in state_dict): converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop( f"head.head.{lora_down_key}.weight" ) converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight") + if "head.head.diff_b" in original_state_dict: + converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b") for text_time in ["text_embedding", "time_embedding"]: if any(text_time in k for k in original_state_dict): @@ -1679,12 +1701,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = ( original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight") ) + if f"{text_time}.{b_n}.diff_b" in original_state_dict: + converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = ( + original_state_dict.pop(f"{text_time}.{b_n}.diff_b") + ) if len(original_state_dict) > 0: diff = all(".diff" in k for k in original_state_dict) if diff: diff_keys = {k for k in original_state_dict if k.endswith(".diff")} - print(f"{len(diff_keys)}") + assert all("lora" not in k for k in diff_keys) + print(f"{(diff_keys)}") if not diff: raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") From 96a665b444c60fe4b88cfff226f37347b20dd68f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 19 May 2025 20:21:53 +0530 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Aryan --- src/diffusers/loaders/lora_conversion_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 7619b6002bf8..e58b09eba9d9 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1610,7 +1610,7 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.") original_state_dict.pop(diff_k) - # For the `diff_b` keys, we treat them as lora_Bias. + # For the `diff_b` keys, we treat them as lora_bias. 
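# Illustrative sketch (hypothetical helpers mirroring the checks in the surrounding hunks):
# all-zero `.diff` / `.diff_b` tensors are dropped up front because they carry no information,
# and after conversion only unhandled `.diff` keys are tolerated as leftovers; anything else
# means the converter met a key it does not understand.
import torch

def prune_zero_diffs(state_dict: dict) -> dict:
    for k in [k for k in state_dict if k.endswith((".diff_b", ".diff"))]:
        if torch.all(state_dict[k] == 0).item():
            state_dict.pop(k)
    return state_dict

def check_leftovers(leftover: dict) -> None:
    if leftover and not all(k.endswith(".diff") for k in leftover):
        raise ValueError(f"Unexpected keys remain after conversion: {sorted(leftover)}")

sd = {
    "blocks.0.norm3.diff": torch.zeros(4),            # all zeros -> pruned
    "blocks.0.self_attn.norm_q.diff": torch.ones(4),  # non-zero, unhandled -> tolerated leftover
}
sd = prune_zero_diffs(sd)
check_leftovers(sd)  # passes: only `.diff` keys remain, so the converter just logs them
print(sorted(sd))    # ['blocks.0.self_attn.norm_q.diff']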
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias for i in range(num_blocks): @@ -1711,8 +1711,7 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): if diff: diff_keys = {k for k in original_state_dict if k.endswith(".diff")} assert all("lora" not in k for k in diff_keys) - print(f"{(diff_keys)}") - if not diff: + else: raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") for key in list(converted_state_dict.keys()): From 3c57997c6a5e76af767ea589954132375f926674 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 19 May 2025 20:30:17 +0530 Subject: [PATCH 4/4] merge ready --- src/diffusers/loaders/lora_conversion_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index e58b09eba9d9..5b12c3aca84d 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1641,7 +1641,6 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): ) if is_i2v_lora: - # TODO: `diff_b` for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" @@ -1649,6 +1648,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" ) + if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict: + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.diff_b" + ) # FFN for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): @@ -1710,7 +1713,12 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): diff = all(".diff" in k for k in original_state_dict) if diff: diff_keys = {k for k in original_state_dict if k.endswith(".diff")} - assert all("lora" not in k for k in diff_keys) + if not all("lora" not in k for k in diff_keys): + raise ValueError + logger.info( + "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: " + "https://github.com/huggingface/diffusers//issues/new" + ) else: raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
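# Usage sketch (assumptions: a diffusers build that includes this converter; the Hub model id
# and the LoRA folder/file names below are placeholders, not verified checkpoints). After the
# final "transformer." prefixing in the function, the converted dict is a regular PEFT-style
# LoRA state dict, so a Kijai-style Wan LoRA loads like any other LoRA.
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# The non-diffusers key layout is detected and converted on the fly inside `load_lora_weights`.
pipe.load_lora_weights("path/to/lora_folder", weight_name="kijai_style_wan_lora.safetensors")
video = pipe(prompt="a cat surfing a wave", num_frames=33).frames[0]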