@@ -294,20 +294,15 @@ def load_lora_into_unet(
294294                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
295295            )
296296
297-         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), 
298-         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as 
299-         # their prefixes. 
300-         if  any (k .startswith (f"{ cls .unet_name }  ) for  k  in  state_dict ):
301-             # Load the layers corresponding to UNet. 
302-             logger .info (f"Loading { cls .unet_name }  )
303-             unet .load_lora_adapter (
304-                 state_dict ,
305-                 prefix = cls .unet_name ,
306-                 network_alphas = network_alphas ,
307-                 adapter_name = adapter_name ,
308-                 _pipeline = _pipeline ,
309-                 low_cpu_mem_usage = low_cpu_mem_usage ,
310-             )
297+         # Load the layers corresponding to UNet. 
298+         unet .load_lora_adapter (
299+             state_dict ,
300+             prefix = cls .unet_name ,
301+             network_alphas = network_alphas ,
302+             adapter_name = adapter_name ,
303+             _pipeline = _pipeline ,
304+             low_cpu_mem_usage = low_cpu_mem_usage ,
305+         )
311306
312307    @classmethod  
313308    def  load_lora_into_text_encoder (
@@ -663,18 +658,16 @@ def load_lora_weights(
663658            _pipeline = self ,
664659            low_cpu_mem_usage = low_cpu_mem_usage ,
665660        )
666-         text_encoder_state_dict  =  {k : v  for  k , v  in  state_dict .items () if  "text_encoder."  in  k }
667-         if  len (text_encoder_state_dict ) >  0 :
668-             self .load_lora_into_text_encoder (
669-                 text_encoder_state_dict ,
670-                 network_alphas = network_alphas ,
671-                 text_encoder = self .text_encoder ,
672-                 prefix = "text_encoder" ,
673-                 lora_scale = self .lora_scale ,
674-                 adapter_name = adapter_name ,
675-                 _pipeline = self ,
676-                 low_cpu_mem_usage = low_cpu_mem_usage ,
677-             )
661+         self .load_lora_into_text_encoder (
662+             state_dict ,
663+             network_alphas = network_alphas ,
664+             text_encoder = self .text_encoder ,
665+             prefix = "text_encoder" ,
666+             lora_scale = self .lora_scale ,
667+             adapter_name = adapter_name ,
668+             _pipeline = self ,
669+             low_cpu_mem_usage = low_cpu_mem_usage ,
670+         )
678671
679672        text_encoder_2_state_dict  =  {k : v  for  k , v  in  state_dict .items () if  "text_encoder_2."  in  k }
680673        if  len (text_encoder_2_state_dict ) >  0 :
@@ -839,20 +832,15 @@ def load_lora_into_unet(
839832                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
840833            )
841834
842-         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), 
843-         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as 
844-         # their prefixes. 
845-         if  any (k .startswith (f"{ cls .unet_name }  ) for  k  in  state_dict ):
846-             # Load the layers corresponding to UNet. 
847-             logger .info (f"Loading { cls .unet_name }  )
848-             unet .load_lora_adapter (
849-                 state_dict ,
850-                 prefix = cls .unet_name ,
851-                 network_alphas = network_alphas ,
852-                 adapter_name = adapter_name ,
853-                 _pipeline = _pipeline ,
854-                 low_cpu_mem_usage = low_cpu_mem_usage ,
855-             )
835+         # Load the layers corresponding to UNet. 
836+         unet .load_lora_adapter (
837+             state_dict ,
838+             prefix = cls .unet_name ,
839+             network_alphas = network_alphas ,
840+             adapter_name = adapter_name ,
841+             _pipeline = _pipeline ,
842+             low_cpu_mem_usage = low_cpu_mem_usage ,
843+         )
856844
857845    @classmethod  
858846    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder 
@@ -1294,43 +1282,35 @@ def load_lora_weights(
12941282        if  not  is_correct_format :
12951283            raise  ValueError ("Invalid LoRA checkpoint." )
12961284
1297-         transformer_state_dict  =  {k : v  for  k , v  in  state_dict .items () if  "transformer."  in  k }
1298-         if  len (transformer_state_dict ) >  0 :
1299-             self .load_lora_into_transformer (
1300-                 state_dict ,
1301-                 transformer = getattr (self , self .transformer_name )
1302-                 if  not  hasattr (self , "transformer" )
1303-                 else  self .transformer ,
1304-                 adapter_name = adapter_name ,
1305-                 _pipeline = self ,
1306-                 low_cpu_mem_usage = low_cpu_mem_usage ,
1307-             )
1285+         self .load_lora_into_transformer (
1286+             state_dict ,
1287+             transformer = getattr (self , self .transformer_name ) if  not  hasattr (self , "transformer" ) else  self .transformer ,
1288+             adapter_name = adapter_name ,
1289+             _pipeline = self ,
1290+             low_cpu_mem_usage = low_cpu_mem_usage ,
1291+         )
13081292
1309-         text_encoder_state_dict  =  {k : v  for  k , v  in  state_dict .items () if  "text_encoder."  in  k }
1310-         if  len (text_encoder_state_dict ) >  0 :
1311-             self .load_lora_into_text_encoder (
1312-                 text_encoder_state_dict ,
1313-                 network_alphas = None ,
1314-                 text_encoder = self .text_encoder ,
1315-                 prefix = "text_encoder" ,
1316-                 lora_scale = self .lora_scale ,
1317-                 adapter_name = adapter_name ,
1318-                 _pipeline = self ,
1319-                 low_cpu_mem_usage = low_cpu_mem_usage ,
1320-             )
1293+         self .load_lora_into_text_encoder (
1294+             state_dict ,
1295+             network_alphas = None ,
1296+             text_encoder = self .text_encoder ,
1297+             prefix = "text_encoder" ,
1298+             lora_scale = self .lora_scale ,
1299+             adapter_name = adapter_name ,
1300+             _pipeline = self ,
1301+             low_cpu_mem_usage = low_cpu_mem_usage ,
1302+         )
13211303
1322-         text_encoder_2_state_dict  =  {k : v  for  k , v  in  state_dict .items () if  "text_encoder_2."  in  k }
1323-         if  len (text_encoder_2_state_dict ) >  0 :
1324-             self .load_lora_into_text_encoder (
1325-                 text_encoder_2_state_dict ,
1326-                 network_alphas = None ,
1327-                 text_encoder = self .text_encoder_2 ,
1328-                 prefix = "text_encoder_2" ,
1329-                 lora_scale = self .lora_scale ,
1330-                 adapter_name = adapter_name ,
1331-                 _pipeline = self ,
1332-                 low_cpu_mem_usage = low_cpu_mem_usage ,
1333-             )
1304+         self .load_lora_into_text_encoder (
1305+             state_dict ,
1306+             network_alphas = None ,
1307+             text_encoder = self .text_encoder_2 ,
1308+             prefix = "text_encoder_2" ,
1309+             lora_scale = self .lora_scale ,
1310+             adapter_name = adapter_name ,
1311+             _pipeline = self ,
1312+             low_cpu_mem_usage = low_cpu_mem_usage ,
1313+         )
13341314
13351315    @classmethod  
13361316    def  load_lora_into_transformer (
@@ -1359,7 +1339,6 @@ def load_lora_into_transformer(
13591339            )
13601340
13611341        # Load the layers corresponding to transformer. 
1362-         logger .info (f"Loading { cls .transformer_name }  )
13631342        transformer .load_lora_adapter (
13641343            state_dict ,
13651344            network_alphas = None ,
@@ -1855,7 +1834,7 @@ def load_lora_weights(
18551834            raise  ValueError ("Invalid LoRA checkpoint." )
18561835
18571836        transformer_lora_state_dict  =  {
1858-             k : state_dict .pop (k ) for  k  in  list (state_dict .keys ()) if  "transformer."  in  k  and  "lora"  in  k 
1837+             k : state_dict .get (k ) for  k  in  list (state_dict .keys ()) if  "transformer."  in  k  and  "lora"  in  k 
18591838        }
18601839        transformer_norm_state_dict  =  {
18611840            k : state_dict .pop (k )
@@ -1875,15 +1854,14 @@ def load_lora_weights(
18751854                "To get a comprehensive list of parameter names that were modified, enable debug logging." 
18761855            )
18771856
1878-         if  len (transformer_lora_state_dict ) >  0 :
1879-             self .load_lora_into_transformer (
1880-                 transformer_lora_state_dict ,
1881-                 network_alphas = network_alphas ,
1882-                 transformer = transformer ,
1883-                 adapter_name = adapter_name ,
1884-                 _pipeline = self ,
1885-                 low_cpu_mem_usage = low_cpu_mem_usage ,
1886-             )
1857+         self .load_lora_into_transformer (
1858+             state_dict ,
1859+             network_alphas = network_alphas ,
1860+             transformer = transformer ,
1861+             adapter_name = adapter_name ,
1862+             _pipeline = self ,
1863+             low_cpu_mem_usage = low_cpu_mem_usage ,
1864+         )
18871865
18881866        if  len (transformer_norm_state_dict ) >  0 :
18891867            transformer ._transformer_norm_layers  =  self ._load_norm_into_transformer (
@@ -1892,18 +1870,16 @@ def load_lora_weights(
18921870                discard_original_layers = False ,
18931871            )
18941872
1895-         text_encoder_state_dict  =  {k : v  for  k , v  in  state_dict .items () if  "text_encoder."  in  k }
1896-         if  len (text_encoder_state_dict ) >  0 :
1897-             self .load_lora_into_text_encoder (
1898-                 text_encoder_state_dict ,
1899-                 network_alphas = network_alphas ,
1900-                 text_encoder = self .text_encoder ,
1901-                 prefix = "text_encoder" ,
1902-                 lora_scale = self .lora_scale ,
1903-                 adapter_name = adapter_name ,
1904-                 _pipeline = self ,
1905-                 low_cpu_mem_usage = low_cpu_mem_usage ,
1906-             )
1873+         self .load_lora_into_text_encoder (
1874+             state_dict ,
1875+             network_alphas = network_alphas ,
1876+             text_encoder = self .text_encoder ,
1877+             prefix = "text_encoder" ,
1878+             lora_scale = self .lora_scale ,
1879+             adapter_name = adapter_name ,
1880+             _pipeline = self ,
1881+             low_cpu_mem_usage = low_cpu_mem_usage ,
1882+         )
19071883
19081884    @classmethod  
19091885    def  load_lora_into_transformer (
@@ -1936,17 +1912,13 @@ def load_lora_into_transformer(
19361912            )
19371913
19381914        # Load the layers corresponding to transformer. 
1939-         keys  =  list (state_dict .keys ())
1940-         transformer_present  =  any (key .startswith (cls .transformer_name ) for  key  in  keys )
1941-         if  transformer_present :
1942-             logger .info (f"Loading { cls .transformer_name }  )
1943-             transformer .load_lora_adapter (
1944-                 state_dict ,
1945-                 network_alphas = network_alphas ,
1946-                 adapter_name = adapter_name ,
1947-                 _pipeline = _pipeline ,
1948-                 low_cpu_mem_usage = low_cpu_mem_usage ,
1949-             )
1915+         transformer .load_lora_adapter (
1916+             state_dict ,
1917+             network_alphas = network_alphas ,
1918+             adapter_name = adapter_name ,
1919+             _pipeline = _pipeline ,
1920+             low_cpu_mem_usage = low_cpu_mem_usage ,
1921+         )
19501922
19511923    @classmethod  
19521924    def  _load_norm_into_transformer (
@@ -2837,7 +2809,6 @@ def load_lora_into_transformer(
28372809            )
28382810
28392811        # Load the layers corresponding to transformer. 
2840-         logger .info (f"Loading { cls .transformer_name }  )
28412812        transformer .load_lora_adapter (
28422813            state_dict ,
28432814            network_alphas = None ,
@@ -3145,7 +3116,6 @@ def load_lora_into_transformer(
31453116            )
31463117
31473118        # Load the layers corresponding to transformer. 
3148-         logger .info (f"Loading { cls .transformer_name }  )
31493119        transformer .load_lora_adapter (
31503120            state_dict ,
31513121            network_alphas = None ,
0 commit comments