@@ -2149,3 +2149,51 @@ def check_module(denoiser):
21492149
21502150            _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
21512151            pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2152+ 
2153+     def  test_inference_load_delete_load_adapters (self ):
2154+         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." 
2155+         for  scheduler_cls  in  self .scheduler_classes :
2156+             components , text_lora_config , denoiser_lora_config  =  self .get_dummy_components (scheduler_cls )
2157+             pipe  =  self .pipeline_class (** components )
2158+             pipe  =  pipe .to (torch_device )
2159+             pipe .set_progress_bar_config (disable = None )
2160+             _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
2161+ 
2162+             output_no_lora  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2163+ 
2164+             if  "text_encoder"  in  self .pipeline_class ._lora_loadable_modules :
2165+                 pipe .text_encoder .add_adapter (text_lora_config )
2166+                 self .assertTrue (
2167+                     check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" 
2168+                 )
2169+ 
2170+             denoiser  =  pipe .transformer  if  self .unet_kwargs  is  None  else  pipe .unet 
2171+             denoiser .add_adapter (denoiser_lora_config )
2172+             self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
2173+ 
2174+             if  self .has_two_text_encoders  or  self .has_three_text_encoders :
2175+                 lora_loadable_components  =  self .pipeline_class ._lora_loadable_modules 
2176+                 if  "text_encoder_2"  in  lora_loadable_components :
2177+                     pipe .text_encoder_2 .add_adapter (text_lora_config )
2178+                     self .assertTrue (
2179+                         check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2" 
2180+                     )
2181+ 
2182+             output_adapter_1  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2183+ 
2184+             with  tempfile .TemporaryDirectory () as  tmpdirname :
2185+                 modules_to_save  =  self ._get_modules_to_save (pipe , has_denoiser = True )
2186+                 lora_state_dicts  =  self ._get_lora_state_dicts (modules_to_save )
2187+                 self .pipeline_class .save_lora_weights (save_directory = tmpdirname , ** lora_state_dicts )
2188+                 self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" )))
2189+ 
2190+                 # First, delete adapter and compare. 
2191+                 pipe .delete_adapters (pipe .get_active_adapters ()[0 ])
2192+                 output_no_adapter  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2193+                 self .assertFalse (np .allclose (output_adapter_1 , output_no_adapter , atol = 1e-3 , rtol = 1e-3 ))
2194+                 self .assertTrue (np .allclose (output_no_lora , output_no_adapter , atol = 1e-3 , rtol = 1e-3 ))
2195+ 
2196+                 # Then load adapter and compare. 
2197+                 pipe .load_lora_weights (tmpdirname )
2198+                 output_lora_loaded  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2199+                 self .assertTrue (np .allclose (output_adapter_1 , output_lora_loaded , atol = 1e-3 , rtol = 1e-3 ))
0 commit comments