@@ -2149,3 +2149,51 @@ def check_module(denoiser):
21492149
21502150 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
21512151 pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2152+
2153+ def test_inference_load_delete_load_adapters (self ):
2154+ "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
2155+ for scheduler_cls in self .scheduler_classes :
2156+ components , text_lora_config , denoiser_lora_config = self .get_dummy_components (scheduler_cls )
2157+ pipe = self .pipeline_class (** components )
2158+ pipe = pipe .to (torch_device )
2159+ pipe .set_progress_bar_config (disable = None )
2160+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
2161+
2162+ output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2163+
2164+ if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
2165+ pipe .text_encoder .add_adapter (text_lora_config )
2166+ self .assertTrue (
2167+ check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder"
2168+ )
2169+
2170+ denoiser = pipe .transformer if self .unet_kwargs is None else pipe .unet
2171+ denoiser .add_adapter (denoiser_lora_config )
2172+ self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
2173+
2174+ if self .has_two_text_encoders or self .has_three_text_encoders :
2175+ lora_loadable_components = self .pipeline_class ._lora_loadable_modules
2176+ if "text_encoder_2" in lora_loadable_components :
2177+ pipe .text_encoder_2 .add_adapter (text_lora_config )
2178+ self .assertTrue (
2179+ check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2"
2180+ )
2181+
2182+ output_adapter_1 = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2183+
2184+ with tempfile .TemporaryDirectory () as tmpdirname :
2185+ modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
2186+ lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
2187+ self .pipeline_class .save_lora_weights (save_directory = tmpdirname , ** lora_state_dicts )
2188+ self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" )))
2189+
2190+ # First, delete adapter and compare.
2191+ pipe .delete_adapters (pipe .get_active_adapters ()[0 ])
2192+ output_no_adapter = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2193+ self .assertFalse (np .allclose (output_adapter_1 , output_no_adapter , atol = 1e-3 , rtol = 1e-3 ))
2194+ self .assertTrue (np .allclose (output_no_lora , output_no_adapter , atol = 1e-3 , rtol = 1e-3 ))
2195+
2196+ # Then load adapter and compare.
2197+ pipe .load_lora_weights (tmpdirname )
2198+ output_lora_loaded = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2199+ self .assertTrue (np .allclose (output_adapter_1 , output_lora_loaded , atol = 1e-3 , rtol = 1e-3 ))
0 commit comments