Skip to content

Commit c5a753a

Browse files
committed
fix
1 parent 85a618d commit c5a753a

File tree

1 file changed

+18
-64
lines changed

1 file changed

+18
-64
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 18 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4813,89 +4813,43 @@ def _maybe_expand_t2v_lora_for_i2v(
48134813
if transformer.config.image_dim is None:
48144814
return state_dict
48154815

4816-
# Determine the target device from the transformer
4817-
try:
4818-
target_device = next(transformer.parameters()).device
4819-
except StopIteration:
4820-
# Fallback if transformer has no parameters (should be rare for a full model)
4821-
# Try to infer from existing tensors in state_dict, else default to CPU
4822-
target_device = torch.device("cpu")
4823-
for v in state_dict.values():
4824-
if torch.is_tensor(v):
4825-
target_device = v.device
4826-
break
4827-
4828-
# Debug prints, can be removed after verification
4829-
logger.debug(f"Target device for new LoRA tensors: {target_device}")
4830-
logger.debug(f"Keys BEFORE _maybe_expand_t2v_lora_for_i2v: {list(state_dict.keys())[:5]}...")
4816+
target_device = transformer.device
48314817

48324818
if any(k.startswith("transformer.blocks.") for k in state_dict):
4833-
block_indices = set()
4834-
# Iterate over a copy of keys if state_dict might be modified during iteration elsewhere (not here though)
4835-
for k in list(state_dict.keys()):
4836-
if k.startswith("transformer.blocks.") and ".attn2.to_k.lora_A.weight" in k:
4837-
try:
4838-
block_idx_str = k.split("blocks.")[1].split(".")[0]
4839-
block_indices.add(int(block_idx_str))
4840-
except (IndexError, ValueError):
4841-
logger.warning(f"Could not parse block index from key: {k}")
4842-
continue
4843-
4844-
if not block_indices:
4845-
logger.debug(
4846-
"No valid blocks found for T2V to I2V expansion referencing based on '.attn2.to_k.lora_A.weight'.")
4847-
return state_dict
4848-
4849-
num_total_blocks = max(block_indices) + 1
4850-
4819+
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
48514820
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
4821+
has_bias = any(".lora_B.bias" in k for k in state_dict)
48524822

4853-
# Check for bias keys that would have been converted by _convert_non_diffusers_wan_lora_to_diffusers
4854-
# e.g., 'transformer.blocks.0.attn2.to_k.lora_B.bias'
4855-
# This helps decide if zero biases should be added for the new projections.
4856-
has_bias_key_pattern_in_sd = any(".lora_B.bias" in k for k in state_dict)
4857-
4858-
if is_i2v_lora: # If it's already an I2V LoRA, no expansion needed
4823+
if is_i2v_lora:
48594824
return state_dict
48604825

4861-
logger.info(
4862-
"Adapting a T2V LoRA for I2V model by adding zero-initialized weights for image-specific cross-attention layers."
4863-
)
4864-
for i in range(num_total_blocks):
4865-
# Define reference key patterns carefully. These keys should exist if the block `i` was part of the T2V LoRA.
4866-
ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
4867-
ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
4868-
4869-
# Only proceed if the reference LoRA weights for this block exist in the T2V LoRA
4870-
if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
4871-
continue
4826+
for i in range(num_blocks):
4827+
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
4828+
# These keys should exist if the block `i` was part of the T2V LoRA.
4829+
ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
4830+
ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
48724831

4873-
ref_lora_A_weight = state_dict[ref_key_lora_A]
4874-
ref_lora_B_weight = state_dict[ref_key_lora_B]
4875-
4876-
# Use dtype from reference LoRA tensors, device from transformer
4877-
lora_dtype = ref_lora_A_weight.dtype
4832+
if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
4833+
continue
4834+
lora_dtype = ref_lora_A_weight.dtype
48784835

4879-
for _unused_orig_name, diffusers_name_part in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
4880-
# Create new tensors on the transformer's device and with the LoRA's dtype
4881-
state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_A.weight"] = torch.zeros(
4836+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros(
48824837
ref_lora_A_weight.shape, device=target_device, dtype=lora_dtype
48834838
)
4884-
state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_B.weight"] = torch.zeros(
4885-
ref_lora_B_weight.shape, device=target_device, dtype=lora_dtype
4839+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
4840+
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
48864841
)
48874842

4888-
# If the original LoRA had biases (indicated by has_bias_key_pattern_in_sd)
4843+
# If the original LoRA had biases (indicated by has_bias)
48894844
# AND the specific reference bias key exists for this block.
4845+
48904846
ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
4891-
if has_bias_key_pattern_in_sd and ref_key_lora_B_bias in state_dict:
4847+
if has_bias and ref_key_lora_B_bias in state_dict:
48924848
ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
48934849
state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_B.bias"] = torch.zeros(
48944850
ref_lora_B_bias_tensor.shape, device=target_device, dtype=lora_dtype
48954851
)
48964852

4897-
# new_keys = set(state_dict.keys()) - orig_keys
4898-
# logger.debug(f"Keys ADDED by _maybe_expand_t2v_lora_for_i2v: {new_keys}")
48994853
return state_dict
49004854

49014855
def load_lora_weights(

0 commit comments

Comments
 (0)