
Commit c665a04

revert
1 parent 78a8e1e commit c665a04

File tree

1 file changed (+6, -52)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 6 additions & 52 deletions
@@ -4810,11 +4810,11 @@ def _maybe_expand_t2v_lora_for_i2v(
         transformer: torch.nn.Module,
         state_dict,
     ):
-        print("wtf 0", hasattr(transformer, 'vace_blocks'))
-        # if transformer.config.image_dim is None:
-        #     return state_dict
+        if transformer.config.image_dim is None:
+            return state_dict
 
         target_device = transformer.device
+
         if any(k.startswith("transformer.blocks.") for k in state_dict):
             num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
             is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
@@ -4833,10 +4833,10 @@ def _maybe_expand_t2v_lora_for_i2v(
                         continue
 
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[ref_key_lora_A], device=target_device  # Using original ref_key_lora_A
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
                     )
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[ref_key_lora_B], device=target_device  # Using original ref_key_lora_B
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
                     )
 
                     # If the original LoRA had biases (indicated by has_bias)
@@ -4849,52 +4849,7 @@ def _maybe_expand_t2v_lora_for_i2v(
                             ref_lora_B_bias_tensor,
                             device=target_device,
                         )
-
-
-        if hasattr(transformer, 'vace_blocks'):
-            print(f"{i}, WTF 0")
-            inferred_rank_for_vace = None
-            lora_weights_dtype_for_vace = next(iter(transformer.parameters())).dtype  # Fallback dtype
-
-            for k_lora_any, v_lora_tensor_any in state_dict.items():
-                if k_lora_any.endswith(".lora_A.weight"):
-                    inferred_rank_for_vace = v_lora_tensor_any.shape[0]
-                    lora_weights_dtype_for_vace = v_lora_tensor_any.dtype
-                    break  # Found one, good enough for rank and dtype
-
-            if inferred_rank_for_vace is not None:
-                current_lora_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())
-
-                for i, vace_block_module_in_model in enumerate(transformer.vace_blocks):
-                    if hasattr(vace_block_module_in_model, 'proj_out'):
-
-                        proj_out_linear_layer_in_model = vace_block_module_in_model.proj_out
-
-                        vace_lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
-                        vace_lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
-
-                        if vace_lora_A_key not in state_dict:
-                            print(f"{i}, WTF 1")
-                            state_dict[vace_lora_A_key] = torch.zeros(
-                                (inferred_rank_for_vace, proj_out_linear_layer_in_model.in_features),
-                                device=target_device, dtype=lora_weights_dtype_for_vace
-                            )
-
-                        if vace_lora_B_key not in state_dict:
-                            print(f"{i}, WTF 2")
-                            state_dict[vace_lora_B_key] = torch.zeros(
-                                (proj_out_linear_layer_in_model.out_features, inferred_rank_for_vace),
-                                device=target_device, dtype=lora_weights_dtype_for_vace
-                            )
-
-                        if current_lora_has_bias and proj_out_linear_layer_in_model.bias is not None:
-                            print(f"{i}, WTF 3")
-                            vace_lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
-                            if vace_lora_B_bias_key not in state_dict:
-                                state_dict[vace_lora_B_bias_key] = torch.zeros_like(
-                                    proj_out_linear_layer_in_model.bias,
-                                    device=target_device
-                                )
+        print(state_dict.keys)
 
         return state_dict
 
@@ -4942,7 +4897,6 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
-        print("_maybe_expand_t2v_lora_for_i2v?????????????????")
        state_dict = self._maybe_expand_t2v_lora_for_i2v(
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             state_dict=state_dict,

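For context on what the restored code path does when a text-to-video (T2V) LoRA is loaded into a Wan image-to-video (I2V) transformer: the checkpoint has no weights for the extra image-conditioning projections (add_k_proj / add_v_proj), so zero-valued lora_A / lora_B tensors shaped like the existing to_k LoRA weights are inserted, which leaves those layers effectively unmodified while avoiding missing-key errors. The sketch below only illustrates that idea; the function name and the num_blocks / device parameters are illustrative, not the diffusers API.

import torch


def expand_t2v_lora_for_i2v_sketch(state_dict, num_blocks, device="cpu"):
    # For every transformer block covered by the T2V LoRA, add zero LoRA
    # weights for the I2V-only projections. A zero lora_A/lora_B pair makes
    # the LoRA delta for these layers exactly zero.
    for i in range(num_blocks):
        ref_a = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
        ref_b = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
        if ref_a not in state_dict or ref_b not in state_dict:
            continue
        for c in ("add_k_proj", "add_v_proj"):
            state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
                state_dict[ref_a], device=device
            )
            state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
                state_dict[ref_b], device=device
            )
    return state_dict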