Skip to content

Commit b694ca4

Browse files
committed
updates
1 parent cd88a4b commit b694ca4

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -842,9 +842,7 @@ def load_lora_into_unet(
842842
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
843843
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
844844
# their prefixes.
845-
keys = list(state_dict.keys())
846-
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
847-
if not only_text_encoder:
845+
if any(k.startswith(f"{cls.unet_name}.") for k in state_dict):
848846
# Load the layers corresponding to UNet.
849847
logger.info(f"Loading {cls.unet_name}.")
850848
unet.load_lora_adapter(
@@ -1008,6 +1006,11 @@ def load_lora_into_text_encoder(
10081006
_pipeline.enable_sequential_cpu_offload()
10091007
# Unsafe code />
10101008

1009+
else:
1010+
logger.info(
1011+
f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
1012+
)
1013+
10111014
@classmethod
10121015
def save_lora_weights(
10131016
cls,
@@ -1517,6 +1520,11 @@ def load_lora_into_text_encoder(
15171520
_pipeline.enable_sequential_cpu_offload()
15181521
# Unsafe code />
15191522

1523+
else:
1524+
logger.info(
1525+
f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
1526+
)
1527+
15201528
@classmethod
15211529
def save_lora_weights(
15221530
cls,
@@ -2146,6 +2154,11 @@ def load_lora_into_text_encoder(
21462154
_pipeline.enable_sequential_cpu_offload()
21472155
# Unsafe code />
21482156

2157+
else:
2158+
logger.info(
2159+
f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
2160+
)
2161+
21492162
@classmethod
21502163
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
21512164
def save_lora_weights(
@@ -2580,6 +2593,11 @@ def load_lora_into_text_encoder(
25802593
_pipeline.enable_sequential_cpu_offload()
25812594
# Unsafe code />
25822595

2596+
else:
2597+
logger.info(
2598+
f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
2599+
)
2600+
25832601
@classmethod
25842602
def save_lora_weights(
25852603
cls,

0 commit comments

Comments
 (0)