@@ -1060,10 +1060,10 @@ def test_deprecated_kwargs(self):
             " from `_deprecated_kwargs = [<deprecated_argument>]`"
         )

-    @parameterized.expand([True, False])
+    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
     @torch.no_grad()
     @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_save_load_lora_adapter(self, use_dora=False):
+    def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
         from peft import LoraConfig
         from peft.utils import get_peft_model_state_dict
@@ -1079,8 +1079,8 @@ def test_save_load_lora_adapter(self, use_dora=False):
         output_no_lora = model(**inputs_dict, return_dict=False)[0]

         denoiser_lora_config = LoraConfig(
-            r=4,
-            lora_alpha=4,
+            r=rank,
+            lora_alpha=lora_alpha,
             target_modules=["to_q", "to_k", "to_v", "to_out.0"],
             init_lora_weights=False,
             use_dora=use_dora,
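Note (not part of the diff): when `parameterized.expand` receives tuples instead of bare values, each tuple is unpacked into positional arguments of the test method, which is what the new `(rank, lora_alpha, use_dora)` signature relies on. A minimal, self-contained sketch of that mechanism, using a hypothetical test class name:

import unittest

from parameterized import parameterized


class LoraParamSketch(unittest.TestCase):
    # Hypothetical stand-in for the real test class; it only illustrates how the
    # decorator maps each tuple onto (rank, lora_alpha, use_dora).
    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
    def test_rank_alpha_dora(self, rank, lora_alpha, use_dora=False):
        # One generated test case per tuple; the tuple elements arrive positionally.
        self.assertIn(rank, (4, 8))
        self.assertIn(lora_alpha, (4, 8))
        self.assertIsInstance(use_dora, bool)


if __name__ == "__main__":
    unittest.main()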
@@ -1147,12 +1147,12 @@ def test_wrong_adapter_name_raises_error(self):

         self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))

+    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
     @torch.no_grad()
     @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_adapter_metadata_is_loaded_correctly(self):
+    def test_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
         from peft import LoraConfig

-        from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
         from diffusers.loaders.peft import PeftAdapterMixin

         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
@@ -1162,11 +1162,11 @@ def test_adapter_metadata_is_loaded_correctly(self):
             return

         denoiser_lora_config = LoraConfig(
-            r=4,
-            lora_alpha=4,
+            r=rank,
+            lora_alpha=lora_alpha,
             target_modules=["to_q", "to_k", "to_v", "to_out.0"],
             init_lora_weights=False,
-            use_dora=False,
+            use_dora=use_dora,
         )
         model.add_adapter(denoiser_lora_config)
         metadata = model.peft_config["default"].to_dict()
@@ -1177,15 +1177,12 @@ def test_adapter_metadata_is_loaded_correctly(self):
             model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
             self.assertTrue(os.path.isfile(model_file))

-            with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
-                if hasattr(f, "metadata"):
-                    parsed_metadata = f.metadata()
-                    parsed_metadata = {k: v for k, v in parsed_metadata.items() if k != "format"}
-                    self.assertTrue(LORA_ADAPTER_METADATA_KEY in parsed_metadata)
-                    parsed_metadata = {k: v for k, v in parsed_metadata.items() if k != "format"}
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

-            parsed_metadata = json.loads(parsed_metadata[LORA_ADAPTER_METADATA_KEY])
-            check_if_dicts_are_equal(parsed_metadata, metadata)
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            parsed_metadata = model.peft_config["default_0"].to_dict()
+            check_if_dicts_are_equal(metadata, parsed_metadata)

     @require_torch_accelerator
     def test_cpu_offload(self):
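Note (not part of the diff): the removed branch validated the saved adapter by reading the safetensors file-level metadata, while the new version reloads the adapter and compares `model.peft_config` entries directly, exercising the load path end to end instead of parsing the file by hand. For reference, a minimal sketch of the old-style read, assuming a LoRA safetensors file whose metadata stores the config as JSON under a key such as diffusers' `LORA_ADAPTER_METADATA_KEY`:

import json

import safetensors.torch


def read_lora_adapter_metadata(model_file, metadata_key):
    # `metadata_key` stands in for the constant the old test imported
    # (diffusers.loaders.lora_base.LORA_ADAPTER_METADATA_KEY).
    with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
        raw = f.metadata() or {}
    # Drop the "format" key that safetensors adds before comparing configs.
    raw = {k: v for k, v in raw.items() if k != "format"}
    return json.loads(raw[metadata_key]) if metadata_key in raw else None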