
Commit f81d4c6

revert
1 parent f91dae9

File tree

1 file changed: +18 -3 lines changed


src/diffusers/loaders/lora_pipeline.py

Lines changed: 18 additions & 3 deletions
@@ -4810,8 +4810,9 @@ def _maybe_expand_t2v_lora_for_i2v(
         transformer: torch.nn.Module,
         state_dict,
     ):
-        # if transformer.config.image_dim is None:
-        # return state_dict
+        print("BEFORE", list(state_dict.keys()))
+        if transformer.config.image_dim is None:
+            return state_dict
 
         target_device = transformer.device
 
@@ -4849,7 +4850,20 @@ def _maybe_expand_t2v_lora_for_i2v(
                             ref_lora_B_bias_tensor,
                             device=target_device,
                         )
-        print(state_dict.keys)
+
+        return state_dict
+
+    @classmethod
+    def _maybe_expand_t2v_lora_for_vace(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+
+        if not hasattr(transformer, 'vace_blocks'):
+            return state_dict
+
+        target_device = transformer.device
 
         return state_dict
 
@@ -4905,6 +4919,7 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
+        print("AFTER:", list(state_dict.keys()))
         self.load_lora_into_transformer(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
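
For orientation, below is a minimal usage sketch of the entry point touched by the last hunk, load_lora_weights() on a Wan pipeline. It is not part of this commit: the checkpoint id and LoRA path are placeholders, and the comment about the _maybe_expand_t2v_lora_for_* helpers reflects only what the hunks above show (their call site is not visible in this diff).

# Sketch only: exercising the Wan LoRA loading path edited above.
import torch
from diffusers import WanPipeline

# Placeholder checkpoint id; any Wan T2V pipeline checkpoint would do here.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)

# load_lora_weights() validates the state dict (the "Invalid LoRA checkpoint."
# branch shown above) and then forwards it to load_lora_into_transformer();
# the _maybe_expand_t2v_lora_for_i2v/_vace helpers are meant to fill in keys
# that a plain T2V LoRA lacks before that hand-off.
pipe.load_lora_weights("path/to/t2v_lora.safetensors", adapter_name="t2v_lora")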

0 commit comments
