diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 954628f2d33d..5b12c3aca84d 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1596,48 +1596,131 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict = {} original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict}) + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k}) is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) + lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down" + lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up" + + diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))] + if diff_keys: + for diff_k in diff_keys: + param = original_state_dict[diff_k] + all_zero = torch.all(param == 0).item() + if all_zero: + logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.") + original_state_dict.pop(diff_k) + + # For the `diff_b` keys, we treat them as lora_bias. + # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias for i in range(num_blocks): # Self-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.lora_A.weight" + f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight" ) converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.lora_B.weight" + f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" ) + if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict: + converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop( + f"blocks.{i}.self_attn.{o}.diff_b" + ) # Cross-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_A.weight" + f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" ) converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_B.weight" + f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" ) + if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict: + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.diff_b" + ) if is_i2v_lora: for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_A.weight" + f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" ) converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_B.weight" + f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" ) + if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict: + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.diff_b" + ) # FFN for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.{o}.lora_A.weight" + f"blocks.{i}.{o}.{lora_down_key}.weight" ) converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.{o}.lora_B.weight" + f"blocks.{i}.{o}.{lora_up_key}.weight" ) + if f"blocks.{i}.{o}.diff_b" in original_state_dict: + converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop( + f"blocks.{i}.{o}.diff_b" + ) + + # Remaining. + if original_state_dict: + if any("time_projection" in k for k in original_state_dict): + converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop( + f"time_projection.1.{lora_down_key}.weight" + ) + converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop( + f"time_projection.1.{lora_up_key}.weight" + ) + if "time_projection.1.diff_b" in original_state_dict: + converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop( + "time_projection.1.diff_b" + ) + + if any("head.head" in k for k in state_dict): + converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop( + f"head.head.{lora_down_key}.weight" + ) + converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight") + if "head.head.diff_b" in original_state_dict: + converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b") + + for text_time in ["text_embedding", "time_embedding"]: + if any(text_time in k for k in original_state_dict): + for b_n in [0, 2]: + diffusers_b_n = 1 if b_n == 0 else 2 + diffusers_name = ( + "condition_embedder.text_embedder" + if text_time == "text_embedding" + else "condition_embedder.time_embedder" + ) + if any(f"{text_time}.{b_n}" in k for k in original_state_dict): + converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = ( + original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight") + ) + converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = ( + original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight") + ) + if f"{text_time}.{b_n}.diff_b" in original_state_dict: + converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = ( + original_state_dict.pop(f"{text_time}.{b_n}.diff_b") + ) if len(original_state_dict) > 0: - raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") + diff = all(".diff" in k for k in original_state_dict) + if diff: + diff_keys = {k for k in original_state_dict if k.endswith(".diff")} + if not all("lora" not in k for k in diff_keys): + raise ValueError + logger.info( + "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: " + "https://github.com/huggingface/diffusers//issues/new" + ) + else: + raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") for key in list(converted_state_dict.keys()): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)