@@ -294,20 +294,15 @@ def load_lora_into_unet(
294294 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
295295 )
296296
297- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
298- # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
299- # their prefixes.
300- if any (k .startswith (f"{ cls .unet_name } ." ) for k in state_dict ):
301- # Load the layers corresponding to UNet.
302- logger .info (f"Loading { cls .unet_name } ." )
303- unet .load_lora_adapter (
304- state_dict ,
305- prefix = cls .unet_name ,
306- network_alphas = network_alphas ,
307- adapter_name = adapter_name ,
308- _pipeline = _pipeline ,
309- low_cpu_mem_usage = low_cpu_mem_usage ,
310- )
297+ # Load the layers corresponding to UNet.
298+ unet .load_lora_adapter (
299+ state_dict ,
300+ prefix = cls .unet_name ,
301+ network_alphas = network_alphas ,
302+ adapter_name = adapter_name ,
303+ _pipeline = _pipeline ,
304+ low_cpu_mem_usage = low_cpu_mem_usage ,
305+ )
311306
312307 @classmethod
313308 def load_lora_into_text_encoder (
@@ -663,18 +658,16 @@ def load_lora_weights(
663658 _pipeline = self ,
664659 low_cpu_mem_usage = low_cpu_mem_usage ,
665660 )
666- text_encoder_state_dict = {k : v for k , v in state_dict .items () if "text_encoder." in k }
667- if len (text_encoder_state_dict ) > 0 :
668- self .load_lora_into_text_encoder (
669- text_encoder_state_dict ,
670- network_alphas = network_alphas ,
671- text_encoder = self .text_encoder ,
672- prefix = "text_encoder" ,
673- lora_scale = self .lora_scale ,
674- adapter_name = adapter_name ,
675- _pipeline = self ,
676- low_cpu_mem_usage = low_cpu_mem_usage ,
677- )
661+ self .load_lora_into_text_encoder (
662+ state_dict ,
663+ network_alphas = network_alphas ,
664+ text_encoder = self .text_encoder ,
665+ prefix = "text_encoder" ,
666+ lora_scale = self .lora_scale ,
667+ adapter_name = adapter_name ,
668+ _pipeline = self ,
669+ low_cpu_mem_usage = low_cpu_mem_usage ,
670+ )
678671
679672 text_encoder_2_state_dict = {k : v for k , v in state_dict .items () if "text_encoder_2." in k }
680673 if len (text_encoder_2_state_dict ) > 0 :
@@ -839,20 +832,15 @@ def load_lora_into_unet(
839832 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
840833 )
841834
842- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
843- # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
844- # their prefixes.
845- if any (k .startswith (f"{ cls .unet_name } ." ) for k in state_dict ):
846- # Load the layers corresponding to UNet.
847- logger .info (f"Loading { cls .unet_name } ." )
848- unet .load_lora_adapter (
849- state_dict ,
850- prefix = cls .unet_name ,
851- network_alphas = network_alphas ,
852- adapter_name = adapter_name ,
853- _pipeline = _pipeline ,
854- low_cpu_mem_usage = low_cpu_mem_usage ,
855- )
835+ # Load the layers corresponding to UNet.
836+ unet .load_lora_adapter (
837+ state_dict ,
838+ prefix = cls .unet_name ,
839+ network_alphas = network_alphas ,
840+ adapter_name = adapter_name ,
841+ _pipeline = _pipeline ,
842+ low_cpu_mem_usage = low_cpu_mem_usage ,
843+ )
856844
857845 @classmethod
858846 # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1294,43 +1282,35 @@ def load_lora_weights(
12941282 if not is_correct_format :
12951283 raise ValueError ("Invalid LoRA checkpoint." )
12961284
1297- transformer_state_dict = {k : v for k , v in state_dict .items () if "transformer." in k }
1298- if len (transformer_state_dict ) > 0 :
1299- self .load_lora_into_transformer (
1300- state_dict ,
1301- transformer = getattr (self , self .transformer_name )
1302- if not hasattr (self , "transformer" )
1303- else self .transformer ,
1304- adapter_name = adapter_name ,
1305- _pipeline = self ,
1306- low_cpu_mem_usage = low_cpu_mem_usage ,
1307- )
1285+ self .load_lora_into_transformer (
1286+ state_dict ,
1287+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
1288+ adapter_name = adapter_name ,
1289+ _pipeline = self ,
1290+ low_cpu_mem_usage = low_cpu_mem_usage ,
1291+ )
13081292
1309- text_encoder_state_dict = {k : v for k , v in state_dict .items () if "text_encoder." in k }
1310- if len (text_encoder_state_dict ) > 0 :
1311- self .load_lora_into_text_encoder (
1312- text_encoder_state_dict ,
1313- network_alphas = None ,
1314- text_encoder = self .text_encoder ,
1315- prefix = "text_encoder" ,
1316- lora_scale = self .lora_scale ,
1317- adapter_name = adapter_name ,
1318- _pipeline = self ,
1319- low_cpu_mem_usage = low_cpu_mem_usage ,
1320- )
1293+ self .load_lora_into_text_encoder (
1294+ state_dict ,
1295+ network_alphas = None ,
1296+ text_encoder = self .text_encoder ,
1297+ prefix = "text_encoder" ,
1298+ lora_scale = self .lora_scale ,
1299+ adapter_name = adapter_name ,
1300+ _pipeline = self ,
1301+ low_cpu_mem_usage = low_cpu_mem_usage ,
1302+ )
13211303
1322- text_encoder_2_state_dict = {k : v for k , v in state_dict .items () if "text_encoder_2." in k }
1323- if len (text_encoder_2_state_dict ) > 0 :
1324- self .load_lora_into_text_encoder (
1325- text_encoder_2_state_dict ,
1326- network_alphas = None ,
1327- text_encoder = self .text_encoder_2 ,
1328- prefix = "text_encoder_2" ,
1329- lora_scale = self .lora_scale ,
1330- adapter_name = adapter_name ,
1331- _pipeline = self ,
1332- low_cpu_mem_usage = low_cpu_mem_usage ,
1333- )
1304+ self .load_lora_into_text_encoder (
1305+ state_dict ,
1306+ network_alphas = None ,
1307+ text_encoder = self .text_encoder_2 ,
1308+ prefix = "text_encoder_2" ,
1309+ lora_scale = self .lora_scale ,
1310+ adapter_name = adapter_name ,
1311+ _pipeline = self ,
1312+ low_cpu_mem_usage = low_cpu_mem_usage ,
1313+ )
13341314
13351315 @classmethod
13361316 def load_lora_into_transformer (
@@ -1359,7 +1339,6 @@ def load_lora_into_transformer(
13591339 )
13601340
13611341 # Load the layers corresponding to transformer.
1362- logger .info (f"Loading { cls .transformer_name } ." )
13631342 transformer .load_lora_adapter (
13641343 state_dict ,
13651344 network_alphas = None ,
@@ -1855,7 +1834,7 @@ def load_lora_weights(
18551834 raise ValueError ("Invalid LoRA checkpoint." )
18561835
18571836 transformer_lora_state_dict = {
1858- k : state_dict .pop (k ) for k in list (state_dict .keys ()) if "transformer." in k and "lora" in k
1837+ k : state_dict .get (k ) for k in list (state_dict .keys ()) if "transformer." in k and "lora" in k
18591838 }
18601839 transformer_norm_state_dict = {
18611840 k : state_dict .pop (k )
@@ -1875,15 +1854,14 @@ def load_lora_weights(
18751854 "To get a comprehensive list of parameter names that were modified, enable debug logging."
18761855 )
18771856
1878- if len (transformer_lora_state_dict ) > 0 :
1879- self .load_lora_into_transformer (
1880- transformer_lora_state_dict ,
1881- network_alphas = network_alphas ,
1882- transformer = transformer ,
1883- adapter_name = adapter_name ,
1884- _pipeline = self ,
1885- low_cpu_mem_usage = low_cpu_mem_usage ,
1886- )
1857+ self .load_lora_into_transformer (
1858+ state_dict ,
1859+ network_alphas = network_alphas ,
1860+ transformer = transformer ,
1861+ adapter_name = adapter_name ,
1862+ _pipeline = self ,
1863+ low_cpu_mem_usage = low_cpu_mem_usage ,
1864+ )
18871865
18881866 if len (transformer_norm_state_dict ) > 0 :
18891867 transformer ._transformer_norm_layers = self ._load_norm_into_transformer (
@@ -1892,18 +1870,16 @@ def load_lora_weights(
18921870 discard_original_layers = False ,
18931871 )
18941872
1895- text_encoder_state_dict = {k : v for k , v in state_dict .items () if "text_encoder." in k }
1896- if len (text_encoder_state_dict ) > 0 :
1897- self .load_lora_into_text_encoder (
1898- text_encoder_state_dict ,
1899- network_alphas = network_alphas ,
1900- text_encoder = self .text_encoder ,
1901- prefix = "text_encoder" ,
1902- lora_scale = self .lora_scale ,
1903- adapter_name = adapter_name ,
1904- _pipeline = self ,
1905- low_cpu_mem_usage = low_cpu_mem_usage ,
1906- )
1873+ self .load_lora_into_text_encoder (
1874+ state_dict ,
1875+ network_alphas = network_alphas ,
1876+ text_encoder = self .text_encoder ,
1877+ prefix = "text_encoder" ,
1878+ lora_scale = self .lora_scale ,
1879+ adapter_name = adapter_name ,
1880+ _pipeline = self ,
1881+ low_cpu_mem_usage = low_cpu_mem_usage ,
1882+ )
19071883
19081884 @classmethod
19091885 def load_lora_into_transformer (
@@ -1936,17 +1912,13 @@ def load_lora_into_transformer(
19361912 )
19371913
19381914 # Load the layers corresponding to transformer.
1939- keys = list (state_dict .keys ())
1940- transformer_present = any (key .startswith (cls .transformer_name ) for key in keys )
1941- if transformer_present :
1942- logger .info (f"Loading { cls .transformer_name } ." )
1943- transformer .load_lora_adapter (
1944- state_dict ,
1945- network_alphas = network_alphas ,
1946- adapter_name = adapter_name ,
1947- _pipeline = _pipeline ,
1948- low_cpu_mem_usage = low_cpu_mem_usage ,
1949- )
1915+ transformer .load_lora_adapter (
1916+ state_dict ,
1917+ network_alphas = network_alphas ,
1918+ adapter_name = adapter_name ,
1919+ _pipeline = _pipeline ,
1920+ low_cpu_mem_usage = low_cpu_mem_usage ,
1921+ )
19501922
19511923 @classmethod
19521924 def _load_norm_into_transformer (
@@ -2837,7 +2809,6 @@ def load_lora_into_transformer(
28372809 )
28382810
28392811 # Load the layers corresponding to transformer.
2840- logger .info (f"Loading { cls .transformer_name } ." )
28412812 transformer .load_lora_adapter (
28422813 state_dict ,
28432814 network_alphas = None ,
@@ -3145,7 +3116,6 @@ def load_lora_into_transformer(
31453116 )
31463117
31473118 # Load the layers corresponding to transformer.
3148- logger .info (f"Loading { cls .transformer_name } ." )
31493119 transformer .load_lora_adapter (
31503120 state_dict ,
31513121 network_alphas = None ,
0 commit comments