From cb051b8a9a89ee2db5013ecef9f37ca6586fee21 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Apr 2025 13:27:25 +0530 Subject: [PATCH 1/4] start cleaning up lora test utils for reusability --- tests/lora/utils.py | 300 ++++++++------------------------------------ 1 file changed, 51 insertions(+), 249 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 8cdb43c9d085..edb419810948 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -260,6 +260,29 @@ def _get_modules_to_save(self, pipe, has_denoiser=False): return modules_to_save + def check_if_adapters_added_correctly(self, pipe, text_lora_config=None, denoiser_lora_config=None): + if text_lora_config is not None: + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + + if denoiser_lora_config is not None: + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + else: + denoiser = None + + if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + return pipe, denoiser + def test_simple_inference(self): """ Tests a simple inference and makes sure it works as expected @@ -289,16 +312,7 @@ def test_simple_inference_with_text_lora(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( @@ -381,22 +395,7 @@ def test_low_cpu_mem_usage_with_loading(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -459,16 +458,7 @@ def test_simple_inference_with_text_lora_and_scale(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( @@ -506,15 +496,7 @@ def test_simple_inference_with_text_lora_fused(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) pipe.fuse_lora() # Fusing should still keep the LoRA layers @@ -546,19 +528,7 @@ def test_simple_inference_with_text_lora_unloaded(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -593,18 +563,7 @@ def test_simple_inference_with_text_lora_save_load(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -655,8 +614,11 @@ def test_simple_inference_with_partial_text_lora(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` # supports missing layers (PR#8324). state_dict = { @@ -699,7 +661,7 @@ def test_simple_inference_save_pretrained(self): Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ for scheduler_cls in self.scheduler_classes: - components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -708,16 +670,7 @@ def test_simple_inference_save_pretrained(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdirname: @@ -731,6 +684,10 @@ def test_simple_inference_save_pretrained(self): "Lora not correctly set in text encoder", ) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: self.assertTrue( @@ -759,22 +716,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -820,22 +762,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( @@ -879,22 +806,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) @@ -932,22 +844,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -983,22 +880,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused( pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1111,20 +993,7 @@ def test_wrong_adapter_name_raises_error(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) with self.assertRaises(ValueError) as err_context: pipe.set_adapters("test") @@ -1143,20 +1012,7 @@ def test_multiple_wrong_adapter_name_raises_error(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} logger = logging.get_logger("diffusers.loaders.lora_base") @@ -1804,20 +1660,7 @@ def test_simple_inference_with_dora(self): output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_dora_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1908,18 +1751,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) @@ -2011,22 +1843,7 @@ def test_set_adapters_match_attention_kwargs(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) lora_scale = 0.5 attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} @@ -2211,22 +2028,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): pipe = pipe.to(torch_device, dtype=compute_dtype) pipe.set_progress_bar_config(disable=None) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) if storage_dtype is not None: denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) From 8de4eb9c24753a964b8faf5f784e124d325b1c9e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Apr 2025 13:33:17 +0530 Subject: [PATCH 2/4] update --- tests/lora/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index edb419810948..01bd4b368759 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -260,24 +260,26 @@ def _get_modules_to_save(self, pipe, has_denoiser=False): return modules_to_save - def check_if_adapters_added_correctly(self, pipe, text_lora_config=None, denoiser_lora_config=None): + def check_if_adapters_added_correctly( + self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default" + ): if text_lora_config is not None: if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) + pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) if denoiser_lora_config is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) + denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") else: denoiser = None if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) + pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) From f42ddf4c90664a3f4b8463c755e509bc162c0d9f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Apr 2025 14:05:12 +0530 Subject: [PATCH 3/4] updates --- tests/lora/utils.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 01bd4b368759..cb6b13f15403 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -658,12 +658,12 @@ def test_simple_inference_with_partial_text_lora(self): "Removing adapters should change the output", ) - def test_simple_inference_save_pretrained(self): + def test_simple_inference_save_pretrained_with_text_lora(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -672,7 +672,7 @@ def test_simple_inference_save_pretrained(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdirname: @@ -681,14 +681,11 @@ def test_simple_inference_save_pretrained(self): pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) pipe_from_pretrained.to(torch_device) - self.assertTrue( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), - "Lora not correctly set in text encoder", - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), + "Lora not correctly set in text encoder", + ) if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -988,6 +985,8 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): ) def test_wrong_adapter_name_raises_error(self): + adapter_name = "adapter-1" + scheduler_cls = self.scheduler_classes[0] components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) @@ -995,7 +994,9 @@ def test_wrong_adapter_name_raises_error(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.check_if_adapters_added_correctly( + pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name + ) with self.assertRaises(ValueError) as err_context: pipe.set_adapters("test") @@ -1003,10 +1004,11 @@ def test_wrong_adapter_name_raises_error(self): self.assertTrue("not in the list of present adapters" in str(err_context.exception)) # test this works. - pipe.set_adapters("adapter-1") + pipe.set_adapters(adapter_name) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] def test_multiple_wrong_adapter_name_raises_error(self): + adapter_name = "adapter-1" scheduler_cls = self.scheduler_classes[0] components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) @@ -1014,20 +1016,22 @@ def test_multiple_wrong_adapter_name_raises_error(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.check_if_adapters_added_correctly( + pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name + ) scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} logger = logging.get_logger("diffusers.loaders.lora_base") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: - pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components) + pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components) wrong_components = sorted(set(scale_with_wrong_components.keys())) msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. " self.assertTrue(msg in str(cap_logger.out)) # test this works. - pipe.set_adapters("adapter-1") + pipe.set_adapters(adapter_name) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] def test_simple_inference_with_text_denoiser_block_scale(self): From 65c4f6b03c4cbbd0cdfddcf42a03d93d93deb00f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Apr 2025 15:43:28 +0530 Subject: [PATCH 4/4] updates --- tests/lora/utils.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index cb6b13f15403..27fef495a484 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -616,25 +616,20 @@ def test_simple_inference_with_partial_text_lora(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) + pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + + state_dict = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` - # supports missing layers (PR#8324). - state_dict = { - f"text_encoder.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() - if "text_model.encoder.layers.4" not in module_name - } + # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` + # supports missing layers (PR#8324). + state_dict = { + f"text_encoder.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() + if "text_model.encoder.layers.4" not in module_name + } if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) state_dict.update( { f"text_encoder_2.{module_name}": param