
Commit ea451d1

Author: linoy
Commit message: fix copies
Parent: 729252e

File tree: 1 file changed, +28 −9 lines


src/diffusers/loaders/lora_pipeline.py

Lines changed: 28 additions & 9 deletions
@@ -5687,15 +5687,34 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise ValueError(
+                    "Cannot load LoRA into transformer_2: transformer_2 is not available for this model. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
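
Usage note (not part of the commit): a minimal sketch of how the new load_into_transformer_2 keyword introduced above is meant to be called. The pipeline class, model id, file paths, and adapter names below are placeholder assumptions for illustration; only the load_into_transformer_2 kwarg and the routing into transformer_2 come from the diff itself.

import torch
from diffusers import SkyReelsV2Pipeline  # pipeline class assumed for illustration

# Placeholder model id; any pipeline exposing both `transformer` and
# `transformer_2` components would behave the same way.
pipe = SkyReelsV2Pipeline.from_pretrained(
    "Skywork/SkyReels-V2-T2V-14B-540P",
    torch_dtype=torch.bfloat16,
)

# Default behavior (unchanged by this commit): the LoRA is loaded into
# pipe.transformer.
pipe.load_lora_weights("path/to/lora.safetensors", adapter_name="style")

# New behavior added by this commit: route the LoRA into pipe.transformer_2.
# A ValueError is raised if the pipeline has no transformer_2 component.
pipe.load_lora_weights(
    "path/to/lora.safetensors",
    adapter_name="style_2",
    load_into_transformer_2=True,
)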
