@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder(
339339    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), 
340340    # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as 
341341    # their prefixes. 
342-     keys  =  list (state_dict .keys ())
343342    prefix  =  text_encoder_name  if  prefix  is  None  else  prefix 
344343
345-     # Safe prefix to check with. 
346-     if  any (text_encoder_name  in  key  for  key  in  keys ):
347-         # Load the layers corresponding to text encoder and make necessary adjustments. 
348-         text_encoder_keys  =  [k  for  k  in  keys  if  k .startswith (prefix ) and  k .split ("." )[0 ] ==  prefix ]
349-         text_encoder_lora_state_dict  =  {
350-             k .replace (f"{ prefix }  , "" ): v  for  k , v  in  state_dict .items () if  k  in  text_encoder_keys 
351-         }
344+     # Load the layers corresponding to text encoder and make necessary adjustments. 
345+     if  prefix  is  not None :
346+         state_dict  =  {k .replace (f"{ prefix }  , "" ): v  for  k , v  in  state_dict .items () if  k .startswith (f"{ prefix }  )}
347+ 
348+     if  len (state_dict ) >  0 :
349+         logger .info (f"Loading { prefix }  )
350+         rank  =  {}
351+         state_dict  =  convert_state_dict_to_diffusers (state_dict )
352+ 
353+         # convert state dict 
354+         state_dict  =  convert_state_dict_to_peft (state_dict )
355+ 
356+         for  name , _  in  text_encoder_attn_modules (text_encoder ):
357+             for  module  in  ("out_proj" , "q_proj" , "k_proj" , "v_proj" ):
358+                 rank_key  =  f"{ name } { module }  
359+                 if  rank_key  not  in state_dict :
360+                     continue 
361+                 rank [rank_key ] =  state_dict [rank_key ].shape [1 ]
362+ 
363+         for  name , _  in  text_encoder_mlp_modules (text_encoder ):
364+             for  module  in  ("fc1" , "fc2" ):
365+                 rank_key  =  f"{ name } { module }  
366+                 if  rank_key  not  in state_dict :
367+                     continue 
368+                 rank [rank_key ] =  state_dict [rank_key ].shape [1 ]
369+ 
370+         if  network_alphas  is  not None :
371+             alpha_keys  =  [k  for  k  in  network_alphas .keys () if  k .startswith (prefix ) and  k .split ("." )[0 ] ==  prefix ]
372+             network_alphas  =  {k .replace (f"{ prefix }  , "" ): v  for  k , v  in  network_alphas .items () if  k  in  alpha_keys }
373+ 
374+         lora_config_kwargs  =  get_peft_kwargs (rank , network_alphas , state_dict , is_unet = False )
375+ 
376+         if  "use_dora"  in  lora_config_kwargs :
377+             if  lora_config_kwargs ["use_dora" ]:
378+                 if  is_peft_version ("<" , "0.9.0" ):
379+                     raise  ValueError (
380+                         "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
381+                     )
382+             else :
383+                 if  is_peft_version ("<" , "0.9.0" ):
384+                     lora_config_kwargs .pop ("use_dora" )
385+ 
386+         if  "lora_bias"  in  lora_config_kwargs :
387+             if  lora_config_kwargs ["lora_bias" ]:
388+                 if  is_peft_version ("<=" , "0.13.2" ):
389+                     raise  ValueError (
390+                         "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
391+                     )
392+             else :
393+                 if  is_peft_version ("<=" , "0.13.2" ):
394+                     lora_config_kwargs .pop ("lora_bias" )
352395
353-         if  len (text_encoder_lora_state_dict ) >  0 :
354-             logger .info (f"Loading { prefix }  )
355-             rank  =  {}
356-             text_encoder_lora_state_dict  =  convert_state_dict_to_diffusers (text_encoder_lora_state_dict )
357- 
358-             # convert state dict 
359-             text_encoder_lora_state_dict  =  convert_state_dict_to_peft (text_encoder_lora_state_dict )
360- 
361-             for  name , _  in  text_encoder_attn_modules (text_encoder ):
362-                 for  module  in  ("out_proj" , "q_proj" , "k_proj" , "v_proj" ):
363-                     rank_key  =  f"{ name } { module }  
364-                     if  rank_key  not  in text_encoder_lora_state_dict :
365-                         continue 
366-                     rank [rank_key ] =  text_encoder_lora_state_dict [rank_key ].shape [1 ]
367- 
368-             for  name , _  in  text_encoder_mlp_modules (text_encoder ):
369-                 for  module  in  ("fc1" , "fc2" ):
370-                     rank_key  =  f"{ name } { module }  
371-                     if  rank_key  not  in text_encoder_lora_state_dict :
372-                         continue 
373-                     rank [rank_key ] =  text_encoder_lora_state_dict [rank_key ].shape [1 ]
374- 
375-             if  network_alphas  is  not None :
376-                 alpha_keys  =  [k  for  k  in  network_alphas .keys () if  k .startswith (prefix ) and  k .split ("." )[0 ] ==  prefix ]
377-                 network_alphas  =  {k .replace (f"{ prefix }  , "" ): v  for  k , v  in  network_alphas .items () if  k  in  alpha_keys }
378- 
379-             lora_config_kwargs  =  get_peft_kwargs (rank , network_alphas , text_encoder_lora_state_dict , is_unet = False )
380- 
381-             if  "use_dora"  in  lora_config_kwargs :
382-                 if  lora_config_kwargs ["use_dora" ]:
383-                     if  is_peft_version ("<" , "0.9.0" ):
384-                         raise  ValueError (
385-                             "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
386-                         )
387-                 else :
388-                     if  is_peft_version ("<" , "0.9.0" ):
389-                         lora_config_kwargs .pop ("use_dora" )
390- 
391-             if  "lora_bias"  in  lora_config_kwargs :
392-                 if  lora_config_kwargs ["lora_bias" ]:
393-                     if  is_peft_version ("<=" , "0.13.2" ):
394-                         raise  ValueError (
395-                             "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
396-                         )
397-                 else :
398-                     if  is_peft_version ("<=" , "0.13.2" ):
399-                         lora_config_kwargs .pop ("lora_bias" )
396+         lora_config  =  LoraConfig (** lora_config_kwargs )
400397
401-             lora_config  =  LoraConfig (** lora_config_kwargs )
398+         # adapter_name 
399+         if  adapter_name  is  None :
400+             adapter_name  =  get_adapter_name (text_encoder )
402401
403-             # adapter_name 
404-             if  adapter_name  is  None :
405-                 adapter_name  =  get_adapter_name (text_encoder )
402+         is_model_cpu_offload , is_sequential_cpu_offload  =  _func_optionally_disable_offloading (_pipeline )
406403
407-             is_model_cpu_offload , is_sequential_cpu_offload  =  _func_optionally_disable_offloading (_pipeline )
404+         # inject LoRA layers and load the state dict 
405+         # in transformers we automatically check whether the adapter name is already in use or not 
406+         text_encoder .load_adapter (
407+             adapter_name = adapter_name ,
408+             adapter_state_dict = state_dict ,
409+             peft_config = lora_config ,
410+             ** peft_kwargs ,
411+         )
408412
409-             # inject LoRA layers and load the state dict 
410-             # in transformers we automatically check whether the adapter name is already in use or not 
411-             text_encoder .load_adapter (
412-                 adapter_name = adapter_name ,
413-                 adapter_state_dict = text_encoder_lora_state_dict ,
414-                 peft_config = lora_config ,
415-                 ** peft_kwargs ,
416-             )
413+         # scale LoRA layers with `lora_scale` 
414+         scale_lora_layers (text_encoder , weight = lora_scale )
417415
418-             # scale LoRA layers with `lora_scale` 
419-             scale_lora_layers (text_encoder , weight = lora_scale )
416+         text_encoder .to (device = text_encoder .device , dtype = text_encoder .dtype )
420417
421-             text_encoder .to (device = text_encoder .device , dtype = text_encoder .dtype )
418+         # Offload back. 
419+         if  is_model_cpu_offload :
420+             _pipeline .enable_model_cpu_offload ()
421+         elif  is_sequential_cpu_offload :
422+             _pipeline .enable_sequential_cpu_offload ()
423+         # Unsafe code /> 
422424
423-             # Offload back. 
424-             if  is_model_cpu_offload :
425-                 _pipeline .enable_model_cpu_offload ()
426-             elif  is_sequential_cpu_offload :
427-                 _pipeline .enable_sequential_cpu_offload ()
428-             # Unsafe code /> 
425+     if  prefix  is  not None  and  not  state_dict :
426+         logger .info (
427+             f"No LoRA keys associated to { text_encoder .__class__ .__name__ } { prefix = }  
428+         )
429429
430430
431431def  _func_optionally_disable_offloading (_pipeline ):
0 commit comments