Skip to content
27 changes: 24 additions & 3 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4813,22 +4813,43 @@ def _maybe_expand_t2v_lora_for_i2v(
if transformer.config.image_dim is None:
return state_dict

target_device = transformer.device

if any(k.startswith("transformer.blocks.") for k in state_dict):
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
has_bias = any(".lora_B.bias" in k for k in state_dict)

if is_i2v_lora:
return state_dict

for i in range(num_blocks):
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
# These keys should exist if the block `i` was part of the T2V LoRA.
ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"

if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
continue

state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
)
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
)

# If the original LoRA had biases (indicated by has_bias)
# AND the specific reference bias key exists for this block.

ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
if has_bias and ref_key_lora_B_bias in state_dict:
ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
ref_lora_B_bias_tensor,
device=target_device,
)

return state_dict

def load_lora_weights(
Expand Down