
Commit e2e3ea0

fixes
1 parent 615e372 commit e2e3ea0

2 files changed: +40 -52 lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 34 additions & 48 deletions
@@ -297,19 +297,15 @@ def load_lora_into_unet(
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
-        keys = list(state_dict.keys())
-        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
-        if not only_text_encoder:
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.unet_name}.")
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def load_lora_into_text_encoder(
@@ -828,19 +824,15 @@ def load_lora_into_unet(
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
-        keys = list(state_dict.keys())
-        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
-        if not only_text_encoder:
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.unet_name}.")
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1900,17 +1892,14 @@ def load_lora_into_transformer(
         )
 
         # Load the layers corresponding to transformer.
-        keys = list(state_dict.keys())
-        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
-        if transformer_present:
-            logger.info(f"Loading {cls.transformer_name}.")
-            transformer.load_lora_adapter(
-                state_dict,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def _load_norm_into_transformer(
@@ -2495,17 +2484,14 @@ def load_lora_into_transformer(
         )
 
         # Load the layers corresponding to transformer.
-        keys = list(state_dict.keys())
-        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
-        if transformer_present:
-            logger.info(f"Loading {cls.transformer_name}.")
-            transformer.load_lora_adapter(
-                state_dict,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
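All four hunks above make the same change: the caller no longer scans the state dict to decide whether UNet or transformer keys are present, and instead always delegates to load_lora_adapter, which now owns the prefix filtering. A minimal, self-contained sketch of the before/after behaviour, using a made-up state dict and plain prints (illustrative only, not library code):

# Illustrative stand-in: a checkpoint that only carries text-encoder LoRA weights.
state_dict = {
    "text_encoder.q_proj.lora_A.weight": "A",
    "text_encoder.q_proj.lora_B.weight": "B",
}
prefix = "unet"  # what cls.unet_name is in the real caller

# Old caller-side guard: skip the UNet branch when every key is a text-encoder key.
only_text_encoder = all(key.startswith("text_encoder") for key in state_dict)
if not only_text_encoder:
    print("old path: would call unet.load_lora_adapter(...)")

# New behaviour: the caller always calls load_lora_adapter; the method filters by
# prefix itself (the same dict comprehension introduced in peft.py below) and is
# left with an empty dict when nothing matches, so there is nothing to load.
filtered = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
print(filtered)  # {} -> no UNet LoRA layers in this checkpoint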

src/diffusers/loaders/peft.py

Lines changed: 6 additions & 4 deletions
@@ -253,10 +253,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
         if prefix is not None:
-            keys = list(state_dict.keys())
-            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
-            if len(model_keys) > 0:
-                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+            state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}):
@@ -369,6 +366,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 _pipeline.enable_sequential_cpu_offload()
             # Unsafe code />
 
+        if prefix is not None and not state_dict:
+            logger.info(
+                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
+            )
+
     def save_lora_adapter(
         self,
         save_directory,
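The second hunk adds an explicit info log when prefix filtering leaves nothing to load, so a LoRA that targets only one component no longer skips the others silently. A hedged usage sketch; the checkpoint and LoRA identifiers below are placeholders, not part of this commit:

import torch
from diffusers import DiffusionPipeline

# Placeholder repo ids: substitute a real pipeline checkpoint and a LoRA that only
# contains weights for a single component (for example, text encoder only).
pipe = DiffusionPipeline.from_pretrained(
    "some-org/some-diffusion-checkpoint", torch_dtype=torch.float16
)

# With this commit, the UNet/transformer branch calls load_lora_adapter unconditionally;
# if no key starts with its prefix, the filtered state dict is empty and the new
# "No LoRA keys associated to ..." info message is logged instead of a silent skip.
pipe.load_lora_weights("some-org/text-encoder-only-lora")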
