
Commit 77b5fa5

make it work with a LoRA that has both text_encoder & unet
1 parent a226920

2 files changed: +13 −13 lines

src/diffusers/loaders/lora_base.py

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ def _optionally_disable_offloading(cls, _pipeline):
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

-        if _pipeline is not None and _pipeline.hf_device_map is None:
+        if _pipeline is not None and hasattr(_pipeline, "hf_device_map") and _pipeline.hf_device_map is None:
             for _, component in _pipeline.components.items():
                 if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                     if not is_model_cpu_offload:
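
For context, a minimal sketch of the failure this guard avoids; MinimalPipeline below is a hypothetical stand-in, not a diffusers class. A pipeline-like object that never populates hf_device_map used to raise AttributeError here, because the attribute was dereferenced unconditionally:

# Hypothetical stand-in for a pipeline that never sets hf_device_map
# (not a diffusers class); shows why the hasattr() guard is needed.
class MinimalPipeline:
    components = {}  # no offloaded nn.Module components either

pipe = MinimalPipeline()

# Old condition: raises AttributeError, since hf_device_map was never set.
#   pipe.hf_device_map is None
# New condition: short-circuits safely on the hasattr() check.
if pipe is not None and hasattr(pipe, "hf_device_map") and pipe.hf_device_map is None:
    print("scan components for CPU-offload hooks")
else:
    print("skip: no hf_device_map attribute on this pipeline")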

src/diffusers/loaders/lora_pipeline.py

Lines changed: 12 additions & 12 deletions
@@ -644,24 +644,24 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
-            unet_config=self.unet.config,
+            unet_config=self.unet.config if hasattr(self, "unet") else None,
             **kwargs,
         )

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
-
-        self.load_lora_into_unet(
-            state_dict,
-            network_alphas=network_alphas,
-            unet=self.unet,
-            adapter_name=adapter_name,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+        if hasattr(self, "unet"):
+            self.load_lora_into_unet(
+                state_dict,
+                network_alphas=network_alphas,
+                unet=self.unet,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
+        if len(text_encoder_state_dict) > 0 and hasattr(self, "text_encoder"):
             self.load_lora_into_text_encoder(
                 text_encoder_state_dict,
                 network_alphas=network_alphas,

@@ -674,7 +674,7 @@ def load_lora_weights(
         )

         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
-        if len(text_encoder_2_state_dict) > 0:
+        if len(text_encoder_2_state_dict) > 0 and hasattr(self, "text_encoder_2"):
             self.load_lora_into_text_encoder(
                 text_encoder_2_state_dict,
                 network_alphas=network_alphas,
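
Taken together, the guards make load_lora_weights tolerant of pipelines that are missing a component: each LoRA target is applied only if the pipeline actually has that attribute. A hedged usage sketch (the LoRA repo id below is a placeholder, not a real checkpoint):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# A LoRA trained on both the unet and the text encoder(s) now loads cleanly;
# a pipeline subclass lacking one of those components simply skips that part.
pipe.load_lora_weights("some-user/sdxl-lora")  # placeholder repo id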
