@@ -297,19 +297,15 @@ def load_lora_into_unet(
297297 # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
298298 # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
299299 # their prefixes.
300- keys = list (state_dict .keys ())
301- only_text_encoder = all (key .startswith (cls .text_encoder_name ) for key in keys )
302- if not only_text_encoder :
303- # Load the layers corresponding to UNet.
304- logger .info (f"Loading { cls .unet_name } ." )
305- unet .load_lora_adapter (
306- state_dict ,
307- prefix = cls .unet_name ,
308- network_alphas = network_alphas ,
309- adapter_name = adapter_name ,
310- _pipeline = _pipeline ,
311- low_cpu_mem_usage = low_cpu_mem_usage ,
312- )
300+ logger .info (f"Loading { cls .unet_name } ." )
301+ unet .load_lora_adapter (
302+ state_dict ,
303+ prefix = cls .unet_name ,
304+ network_alphas = network_alphas ,
305+ adapter_name = adapter_name ,
306+ _pipeline = _pipeline ,
307+ low_cpu_mem_usage = low_cpu_mem_usage ,
308+ )
313309
314310 @classmethod
315311 def load_lora_into_text_encoder (
@@ -828,19 +824,15 @@ def load_lora_into_unet(
828824 # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
829825 # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
830826 # their prefixes.
831- keys = list (state_dict .keys ())
832- only_text_encoder = all (key .startswith (cls .text_encoder_name ) for key in keys )
833- if not only_text_encoder :
834- # Load the layers corresponding to UNet.
835- logger .info (f"Loading { cls .unet_name } ." )
836- unet .load_lora_adapter (
837- state_dict ,
838- prefix = cls .unet_name ,
839- network_alphas = network_alphas ,
840- adapter_name = adapter_name ,
841- _pipeline = _pipeline ,
842- low_cpu_mem_usage = low_cpu_mem_usage ,
843- )
827+ logger .info (f"Loading { cls .unet_name } ." )
828+ unet .load_lora_adapter (
829+ state_dict ,
830+ prefix = cls .unet_name ,
831+ network_alphas = network_alphas ,
832+ adapter_name = adapter_name ,
833+ _pipeline = _pipeline ,
834+ low_cpu_mem_usage = low_cpu_mem_usage ,
835+ )
844836
845837 @classmethod
846838 # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1900,17 +1892,14 @@ def load_lora_into_transformer(
19001892 )
19011893
19021894 # Load the layers corresponding to transformer.
1903- keys = list (state_dict .keys ())
1904- transformer_present = any (key .startswith (cls .transformer_name ) for key in keys )
1905- if transformer_present :
1906- logger .info (f"Loading { cls .transformer_name } ." )
1907- transformer .load_lora_adapter (
1908- state_dict ,
1909- network_alphas = network_alphas ,
1910- adapter_name = adapter_name ,
1911- _pipeline = _pipeline ,
1912- low_cpu_mem_usage = low_cpu_mem_usage ,
1913- )
1895+ logger .info (f"Loading { cls .transformer_name } ." )
1896+ transformer .load_lora_adapter (
1897+ state_dict ,
1898+ network_alphas = network_alphas ,
1899+ adapter_name = adapter_name ,
1900+ _pipeline = _pipeline ,
1901+ low_cpu_mem_usage = low_cpu_mem_usage ,
1902+ )
19141903
19151904 @classmethod
19161905 def _load_norm_into_transformer (
@@ -2495,17 +2484,14 @@ def load_lora_into_transformer(
24952484 )
24962485
24972486 # Load the layers corresponding to transformer.
2498- keys = list (state_dict .keys ())
2499- transformer_present = any (key .startswith (cls .transformer_name ) for key in keys )
2500- if transformer_present :
2501- logger .info (f"Loading { cls .transformer_name } ." )
2502- transformer .load_lora_adapter (
2503- state_dict ,
2504- network_alphas = network_alphas ,
2505- adapter_name = adapter_name ,
2506- _pipeline = _pipeline ,
2507- low_cpu_mem_usage = low_cpu_mem_usage ,
2508- )
2487+ logger .info (f"Loading { cls .transformer_name } ." )
2488+ transformer .load_lora_adapter (
2489+ state_dict ,
2490+ network_alphas = network_alphas ,
2491+ adapter_name = adapter_name ,
2492+ _pipeline = _pipeline ,
2493+ low_cpu_mem_usage = low_cpu_mem_usage ,
2494+ )
25092495
25102496 @classmethod
25112497 # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
0 commit comments