@@ -186,6 +186,55 @@ def test_lora_expansion_works_for_absent_keys(self):
186186 "LoRA should lead to different results." ,
187187 )
188188
189+ with tempfile .TemporaryDirectory () as tmpdirname :
190+ denoiser_state_dict = get_peft_model_state_dict (pipe .transformer )
191+ self .pipeline_class .save_lora_weights (tmpdirname , transformer_lora_layers = denoiser_state_dict )
192+
193+ self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" )))
194+ pipe .unload_lora_weights ()
195+ pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" ), adapter_name = "one" )
196+
197+ # Modify the state dict to exclude "x_embedder" related LoRA params.
198+ lora_state_dict = safetensors .torch .load_file (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" ))
199+ lora_state_dict_without_xembedder = {k : v for k , v in lora_state_dict .items () if "x_embedder" not in k }
200+
201+ pipe .load_lora_weights (lora_state_dict_without_xembedder , adapter_name = "two" )
202+ pipe .set_adapters (["one" , "two" ])
203+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in transformer" )
204+ images_lora_with_absent_keys = pipe (** inputs , generator = torch .manual_seed (0 )).images
205+
206+ self .assertFalse (
207+ np .allclose (images_lora , images_lora_with_absent_keys , atol = 1e-3 , rtol = 1e-3 ),
208+ "Different LoRAs should lead to different results." ,
209+ )
210+ self .assertFalse (
211+ np .allclose (output_no_lora , images_lora_with_absent_keys , atol = 1e-3 , rtol = 1e-3 ),
212+ "LoRA should lead to different results." ,
213+ )
214+
215+ def test_lora_expansion_works_for_extra_keys (self ):
216+ components , _ , denoiser_lora_config = self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
217+ pipe = self .pipeline_class (** components )
218+ pipe = pipe .to (torch_device )
219+ pipe .set_progress_bar_config (disable = None )
220+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
221+
222+ output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 )).images
223+ self .assertTrue (output_no_lora .shape == self .output_shape )
224+
225+ # Modify the config to have a layer which won't be present in the first LoRA we will load.
226+ modified_denoiser_lora_config = copy .deepcopy (denoiser_lora_config )
227+ modified_denoiser_lora_config .target_modules .add ("x_embedder" )
228+
229+ pipe .transformer .add_adapter (modified_denoiser_lora_config )
230+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in transformer" )
231+
232+ images_lora = pipe (** inputs , generator = torch .manual_seed (0 )).images
233+ self .assertFalse (
234+ np .allclose (images_lora , output_no_lora , atol = 1e-3 , rtol = 1e-3 ),
235+ "LoRA should lead to different results." ,
236+ )
237+
189238 with tempfile .TemporaryDirectory () as tmpdirname :
190239 denoiser_state_dict = get_peft_model_state_dict (pipe .transformer )
191240 self .pipeline_class .save_lora_weights (tmpdirname , transformer_lora_layers = denoiser_state_dict )
0 commit comments