diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 7bde2a00be97..d797222e8393 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1596,7 +1596,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict = {} original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k}) + block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")} + min_block = min(block_numbers) + max_block = max(block_numbers) + is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down" lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up" @@ -1622,45 +1625,57 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): # For the `diff_b` keys, we treat them as lora_bias. # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias - for i in range(num_blocks): + for i in range(min_block, max_block + 1): # Self-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight" - ) - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" - ) - if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict: - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.diff_b" - ) + original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.self_attn.{o}.diff_b" + converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) # Cross-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" - ) - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" - ) - if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict: - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.diff_b" - ) + original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.diff_b" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) if is_i2v_lora: for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" - ) - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" - ) - if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict: - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.diff_b" - ) + original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.diff_b" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) # FFN for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): @@ -1674,10 +1689,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): if original_key in original_state_dict: converted_state_dict[converted_key] = original_state_dict.pop(original_key) - if f"blocks.{i}.{o}.diff_b" in original_state_dict: - converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop( - f"blocks.{i}.{o}.diff_b" - ) + original_key = f"blocks.{i}.{o}.diff_b" + converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) # Remaining. if original_state_dict: