@@ -44,6 +44,7 @@
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     WEIGHTS_INDEX_NAME,
+    is_peft_available,
     is_torch_npu_available,
     is_xformers_available,
     logging,
@@ -53,6 +54,7 @@
     CaptureLogger,
     get_python_version,
     is_torch_compile,
+    require_peft_backend,
     require_torch_2,
     require_torch_accelerator_with_training,
     require_torch_gpu,
@@ -65,6 +67,13 @@
 from ..others.test_utils import TOKEN, USER, is_staging_test
 
 
+if is_peft_available():
+    from peft import LoraConfig
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    from diffusers.loaders import PeftAdapterMixin
+
+
 def caculate_expected_num_shards(index_map_path):
     with open(index_map_path) as f:
         weight_map_dict = json.load(f)["weight_map"]
@@ -74,6 +83,16 @@ def caculate_expected_num_shards(index_map_path):
     return expected_num_shards
 
 
+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Checks if the LoRA layers are correctly set with peft
+    """
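+    # PEFT replaces each targeted module with a BaseTunerLayer subclass, so
+    # finding one anywhere in the module tree proves an adapter was injected.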
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
 # Will be run via run_test_in_subprocess
 def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     error = None
@@ -902,6 +921,69 @@ def test_deprecated_kwargs(self):
                 " from `_deprecated_kwargs = [<deprecated_argument>]`"
             )
 
+    @parameterized.expand([True, False])
+    @require_peft_backend
+    def test_load_save_lora_adapter(self, use_dora=False):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
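+        # Only models that opt into PeftAdapterMixin expose add_adapter() and
+        # the LoRA save/load helpers, so skip every other model class.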
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
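+        # Seeded baseline forward pass before any adapter is attached.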
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict).sample
+
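+        # init_lora_weights=False leaves the LoRA matrices randomly initialized
+        # (instead of the zero-init identity), so the adapter measurably
+        # changes the model's outputs.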
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
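+        # With the adapter active, outputs must diverge from the baseline.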
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict).sample
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
+
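+        # Round-trip the adapter through disk: save it, strip it from the
+        # model, reload it, and expect the same outputs again (within tolerance).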
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model.unload_lora()
+            model.load_lora_adapter(tmpdir, use_safetensors=True)
+            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora_2 = model(**inputs_dict).sample
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+        self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+
+    @require_peft_backend
+    def test_wrong_adapter_name_raises_error(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
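+        # Saving under an adapter name that was never added must raise a ValueError.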
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with self.assertRaises(ValueError) as err_context:
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+            self.assertTrue(f"Adapter name {wrong_name} not found in the model" in str(err_context.exception))
+
     @require_torch_gpu
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()