
Commit b3394d4

revert
1 parent 26dcfd0 commit b3394d4

File tree: 1 file changed, +42 additions, -145 deletions


src/diffusers/loaders/lora_pipeline.py

Lines changed: 42 additions & 145 deletions
@@ -16,7 +16,6 @@
 from typing import Callable, Dict, List, Optional, Union
 
 import torch
-import re
 from huggingface_hub.utils import validate_hf_hub_args
 
 from ..utils import (
@@ -4806,152 +4805,50 @@ def lora_state_dict(
         return state_dict
 
     @classmethod
-    def _modified_maybe_expand_t2v_lora( # Renamed for clarity
-        # cls, # if it were a classmethod
-        transformer: torch.nn.Module,
-        state_dict: Dict[str, torch.Tensor],
-        lora_filename_for_rank_inference: Optional[str] = None # Optional: for rank hint
-    ) -> Dict[str, torch.Tensor]:
+    def _maybe_expand_t2v_lora_for_i2v(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+        if transformer.config.image_dim is None:
+            return state_dict
 
         target_device = transformer.device
-        # Default dtype from transformer, can be refined if LoRA weights have a different one
-        lora_weights_dtype = next(iter(transformer.parameters())).dtype
-
-        # --- Infer LoRA rank and potentially refine dtype from existing LoRA weights ---
-        inferred_rank = None
-        if state_dict:  # If LoRA state_dict already has entries from the T2V LoRA
-            for k, v_tensor in state_dict.items():
-                if k.endswith(".lora_A.weight"):  # Standard LoRA weight key part
-                    inferred_rank = v_tensor.shape[0]  # rank is the output dim of lora_A
-                    lora_weights_dtype = v_tensor.dtype  # Use dtype of existing LoRA weights
-                    break  # Found rank and dtype
-
-        if inferred_rank is None and lora_filename_for_rank_inference:
-            match = re.search(r"rank(\d+)", lora_filename_for_rank_inference, re.IGNORECASE)
-            if match:
-                inferred_rank = int(match.group(1))
-                print(f"INFO: Inferred LoRA rank {inferred_rank} from filename for padding.")
-
-        # Determine if the original LoRA format (the T2V part) uses biases for lora_B
-        lora_format_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())
-
-        # --- Part 1: Original I2V expansion for standard transformer.blocks ---
-        # (Assuming transformer.config and transformer.blocks structure for this part)
-        if hasattr(transformer, 'config') and hasattr(transformer.config, 'image_dim') and \
-                transformer.config.image_dim is not None and hasattr(transformer, 'blocks'):
-
-            standard_block_keys_present = any(k.startswith("transformer.blocks.") for k in state_dict)
-
-            if standard_block_keys_present and inferred_rank is not None:
-                num_blocks_in_lora = 0
-                block_indices = set()
-                for k_lora in state_dict:
-                    if "transformer.blocks." in k_lora:
-                        try:
-                            block_idx_str = k_lora.split("transformer.blocks.")[1].split(".")[0]
-                            if block_idx_str.isdigit():
-                                block_indices.add(int(block_idx_str))
-                        except IndexError:
-                            pass
-                if block_indices:
-                    num_blocks_in_lora = max(block_indices) + 1
-
-                is_i2v_lora_standard_blocks = any(
-                    k.startswith("transformer.blocks.") and "add_k_proj" in k for k in state_dict
-                ) and any(
-                    k.startswith("transformer.blocks.") and "add_v_proj" in k for k in state_dict
-                )
 
-                if not is_i2v_lora_standard_blocks and num_blocks_in_lora > 0:
-                    print(f"INFO: Expanding T2V LoRA for I2V compatibility (standard blocks). Rank: {inferred_rank}")
-                    for i in range(num_blocks_in_lora):
-                        # Check if block 'i' relevant parts are in the T2V LoRA
-                        ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
-                        if ref_key_lora_A not in state_dict:
-                            continue  # This block's specific part wasn't in the LoRA.
-
-                        try:
-                            model_block = transformer.blocks[i]
-                            # Ensure these target layers exist in the model's standard block
-                            if not (hasattr(model_block, 'attn2') and \
-                                    hasattr(model_block.attn2, 'add_k_proj') and \
-                                    hasattr(model_block.attn2, 'add_v_proj')):
-                                continue
-                            add_k_proj_layer = model_block.attn2.add_k_proj
-                            add_v_proj_layer = model_block.attn2.add_v_proj
-                        except (AttributeError, IndexError):
-                            print(f"WARN: Cannot access standard block {i} or its I2V layers for expansion.")
-                            continue
-
-                        for proj_name_suffix, model_linear_layer in [("add_k_proj", add_k_proj_layer),
-                                                                     ("add_v_proj", add_v_proj_layer)]:
-                            if not isinstance(model_linear_layer, nn.Linear): continue
-
-                            lora_A_key = f"transformer.blocks.{i}.attn2.{proj_name_suffix}.lora_A.weight"
-                            lora_B_key = f"transformer.blocks.{i}.attn2.{proj_name_suffix}.lora_B.weight"
-
-                            if lora_A_key not in state_dict:
-                                state_dict[lora_A_key] = torch.zeros(
-                                    (inferred_rank, model_linear_layer.in_features),
-                                    device=target_device, dtype=lora_weights_dtype
-                                )
-                            if lora_B_key not in state_dict:
-                                state_dict[lora_B_key] = torch.zeros(
-                                    (model_linear_layer.out_features, inferred_rank),
-                                    device=target_device, dtype=lora_weights_dtype
-                                )
-
-                            if lora_format_has_bias and model_linear_layer.bias is not None:
-                                lora_B_bias_key = f"transformer.blocks.{i}.attn2.{proj_name_suffix}.lora_B.bias"
-                                if lora_B_bias_key not in state_dict:
-                                    state_dict[lora_B_bias_key] = torch.zeros_like(
-                                        model_linear_layer.bias, device=target_device,
-                                        dtype=model_linear_layer.bias.dtype
-                                    )
-            elif inferred_rank is None:
-                print("INFO: LoRA rank not inferred. Skipping I2V expansion for standard blocks.")
-        # else: not standard_block_keys_present or no I2V capability.
-
-        # --- Part 2: Pad LoRA for WanVACETransformer3DModel vace_blocks.X.proj_out ---
-        # Dynamically check for WanVACETransformer3DModel availability for isinstance
-        VACEModelClass = globals().get("WanVACETransformer3DModel")
-
-        if VACEModelClass and isinstance(transformer, VACEModelClass) and hasattr(transformer, 'vace_blocks'):
-            if inferred_rank is None:
-                print("WARNING: LoRA rank not determined. Skipping VACE block padding for proj_out.")
-            else:
-                print(f"INFO: Transformer is WanVACE. Padding LoRA for vace_blocks.X.proj_out. Rank: {inferred_rank}")
-                for i, vace_block_module in enumerate(transformer.vace_blocks):
-                    if hasattr(vace_block_module, 'proj_out') and isinstance(vace_block_module.proj_out, nn.Linear):
-                        proj_out_layer = vace_block_module.proj_out
-
-                        # Keys for the vace_block's proj_out LoRA layers
-                        # These are the keys PEFT expects in the state_dict *before* adding adapter name context
-                        lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
-                        lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
-
-                        if lora_A_key not in state_dict:
-                            state_dict[lora_A_key] = torch.zeros(
-                                (inferred_rank, proj_out_layer.in_features),
-                                device=target_device, dtype=lora_weights_dtype
-                            )
-                            # print(f"Padded: {lora_A_key}")
-
-                        if lora_B_key not in state_dict:
-                            state_dict[lora_B_key] = torch.zeros(
-                                (proj_out_layer.out_features, inferred_rank),
-                                device=target_device, dtype=lora_weights_dtype
-                            )
-                            # print(f"Padded: {lora_B_key}")
-
-                        if lora_format_has_bias and proj_out_layer.bias is not None:
-                            lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
-                            if lora_B_bias_key not in state_dict:
-                                state_dict[lora_B_bias_key] = torch.zeros_like(
-                                    proj_out_layer.bias, device=target_device, dtype=proj_out_layer.bias.dtype
-                                )
-                            # print(f"Padded: {lora_B_bias_key}")
-                    # else: VACE block 'i' might not have proj_out or it's not Linear.
+        if any(k.startswith("transformer.blocks.") for k in state_dict):
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
+            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+            has_bias = any(".lora_B.bias" in k for k in state_dict)
+
+            if is_i2v_lora:
+                return state_dict
+
+            for i in range(num_blocks):
+                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    # These keys should exist if the block `i` was part of the T2V LoRA.
+                    ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                    ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+                    if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+                        continue
+
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+                    )
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+                    )
+
+                    # If the original LoRA had biases (indicated by has_bias)
+                    # AND the specific reference bias key exists for this block.
+
+                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+                    if has_bias and ref_key_lora_B_bias in state_dict:
+                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+                        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
+                            ref_lora_B_bias_tensor,
+                            device=target_device,
+                        )
 
         return state_dict

@@ -5816,4 +5713,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
         deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
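
Note on the restored helper: a Wan T2V LoRA carries no weights for the image-conditioning projections (`add_k_proj` / `add_v_proj`) that the I2V transformer adds to `attn2`, so `_maybe_expand_t2v_lora_for_i2v` pads the state dict with zero-valued LoRA factors shaped like the existing `to_k` entries, and returns early when `transformer.config.image_dim` is None (a plain T2V model). Below is a minimal standalone sketch of that padding applied to a toy state dict; the rank, feature size, and block count are hypothetical and not taken from any real checkpoint.

```python
import torch

# Toy T2V LoRA state dict for a single transformer block (hypothetical rank/width).
rank, dim = 4, 16
state_dict = {
    "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(rank, dim),
    "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(dim, rank),
}

# Same key bookkeeping as the restored helper: count blocks, detect an existing I2V LoRA.
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)

if not is_i2v_lora:
    for i in range(num_blocks):
        for c in ("add_k_proj", "add_v_proj"):
            ref_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
            ref_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
            if ref_A not in state_dict or ref_B not in state_dict:
                continue
            # Zero factors: lora_B @ lora_A == 0, so the padded entries change nothing at inference.
            state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(state_dict[ref_A])
            state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(state_dict[ref_B])

print(sorted(state_dict))  # now also lists add_k_proj / add_v_proj entries for block 0
```

Because both padded factors are zero, the added entries contribute nothing to the output; the padding only keeps a T2V LoRA loadable into an I2V pipeline without missing-key errors.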
