Commit 85a618d

testing
1 parent f7dda02 commit 85a618d

File tree

1 file changed: +73, -14 lines

src/diffusers/loaders/lora_pipeline.py

Lines changed: 73 additions & 14 deletions
@@ -4813,30 +4813,89 @@ def _maybe_expand_t2v_lora_for_i2v(
         if transformer.config.image_dim is None:
             return state_dict
 
-        target_device = transformer.device
+        # Determine the target device from the transformer
+        try:
+            target_device = next(transformer.parameters()).device
+        except StopIteration:
+            # Fallback if transformer has no parameters (should be rare for a full model)
+            # Try to infer from existing tensors in state_dict, else default to CPU
+            target_device = torch.device("cpu")
+            for v in state_dict.values():
+                if torch.is_tensor(v):
+                    target_device = v.device
+                    break
+
+        # Debug prints, can be removed after verification
+        logger.debug(f"Target device for new LoRA tensors: {target_device}")
+        logger.debug(f"Keys BEFORE _maybe_expand_t2v_lora_for_i2v: {list(state_dict.keys())[:5]}...")
 
         if any(k.startswith("transformer.blocks.") for k in state_dict):
-            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
+            block_indices = set()
+            # Iterate over a copy of keys if state_dict might be modified during iteration elsewhere (not here though)
+            for k in list(state_dict.keys()):
+                if k.startswith("transformer.blocks.") and ".attn2.to_k.lora_A.weight" in k:
+                    try:
+                        block_idx_str = k.split("blocks.")[1].split(".")[0]
+                        block_indices.add(int(block_idx_str))
+                    except (IndexError, ValueError):
+                        logger.warning(f"Could not parse block index from key: {k}")
+                        continue
+
+            if not block_indices:
+                logger.debug(
+                    "No valid blocks found for T2V to I2V expansion referencing based on '.attn2.to_k.lora_A.weight'.")
+                return state_dict
+
+            num_total_blocks = max(block_indices) + 1
+
             is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
-            is_bias = any("bias" in k for k in state_dict)
 
-            if is_i2v_lora:
+            # Check for bias keys that would have been converted by _convert_non_diffusers_wan_lora_to_diffusers,
+            # e.g. 'transformer.blocks.0.attn2.to_k.lora_B.bias'.
+            # This helps decide if zero biases should be added for the new projections.
+            has_bias_key_pattern_in_sd = any(".lora_B.bias" in k for k in state_dict)
+
+            if is_i2v_lora:  # If it's already an I2V LoRA, no expansion needed
                 return state_dict
 
-            for i in range(num_blocks):
-                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+            logger.info(
+                "Adapting a T2V LoRA for I2V model by adding zero-initialized weights for image-specific cross-attention layers."
+            )
+            for i in range(num_total_blocks):
+                # Define reference key patterns carefully. These keys should exist if block `i` was part of the T2V LoRA.
+                ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+                # Only proceed if the reference LoRA weights for this block exist in the T2V LoRA
+                if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+                    continue
+
+                ref_lora_A_weight = state_dict[ref_key_lora_A]
+                ref_lora_B_weight = state_dict[ref_key_lora_B]
+
+                # Use dtype from reference LoRA tensors, device from transformer
+                lora_dtype = ref_lora_A_weight.dtype
+
+                for _unused_orig_name, diffusers_name_part in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    # Create new tensors on the transformer's device and with the LoRA's dtype
+                    state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_A.weight"] = torch.zeros(
+                        ref_lora_A_weight.shape, device=target_device, dtype=lora_dtype
                     )
-                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+                    state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_B.weight"] = torch.zeros(
+                        ref_lora_B_weight.shape, device=target_device, dtype=lora_dtype
                     )
-                    if is_bias:
-                        state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
-                            state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"], device=target_device
-                        )
 
+                    # If the original LoRA had biases (indicated by has_bias_key_pattern_in_sd)
+                    # AND the specific reference bias key exists for this block.
+                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+                    if has_bias_key_pattern_in_sd and ref_key_lora_B_bias in state_dict:
+                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+                        state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_B.bias"] = torch.zeros(
+                            ref_lora_B_bias_tensor.shape, device=target_device, dtype=lora_dtype
                        )
 
+        # new_keys = set(state_dict.keys()) - orig_keys
+        # logger.debug(f"Keys ADDED by _maybe_expand_t2v_lora_for_i2v: {new_keys}")
         return state_dict
 
     def load_lora_weights(
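
The expansion above relies on zero-initialized LoRA factors being a no-op. A minimal sketch of that property, with arbitrary illustrative shapes (the rank and projection width are not read from the Wan transformer): the update a LoRA contributes is lora_B @ lora_A, so all-zero A and B matrices leave the new add_k_proj / add_v_proj projections mathematically untouched, and the expanded adapter behaves exactly like the original T2V LoRA.

import torch

rank, dim = 4, 16                 # illustrative LoRA rank and projection width
lora_A = torch.zeros(rank, dim)   # mirrors a zero-initialized add_k_proj lora_A.weight
lora_B = torch.zeros(dim, rank)   # mirrors a zero-initialized add_k_proj lora_B.weight

delta_w = lora_B @ lora_A         # weight update contributed by this adapter
assert torch.count_nonzero(delta_w) == 0  # all-zero update: the image layers are effectively unmodified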

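For broader context, a rough usage sketch of the path this commit touches, assuming the surrounding Wan LoRA loader: load_lora_weights() on an image-to-video pipeline is expected to run this T2V-to-I2V expansion before attaching the adapter. The checkpoint id and LoRA repo/filename below are placeholders, not taken from this commit.

import torch
from diffusers import WanImageToVideoPipeline

# Placeholder checkpoint; any Wan I2V model whose transformer config sets image_dim exercises this path.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# A LoRA trained on the T2V transformer has no add_k_proj / add_v_proj entries;
# the expansion adds zero-initialized ones so the adapter can be injected into the I2V model.
pipe.load_lora_weights("your-user/wan-t2v-lora", weight_name="pytorch_lora_weights.safetensors")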