From 689f172565fa137f5671328f83e3cc0374ff676a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Jun 2025 20:27:19 +0530 Subject: [PATCH 1/3] fix: lora unloading behvaiour --- src/diffusers/loaders/peft.py | 2 + tests/lora/utils.py | 69 +++++++++++++++++++++++------------ 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 343623071340..039301e0f814 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -686,6 +686,8 @@ def unload_lora(self): recurse_remove_peft_layers(self) if hasattr(self, "peft_config"): del self.peft_config + if hasattr(self, "_hf_peft_config_loaded"): + self._hf_peft_config_loaded = None def disable_lora(self): """ diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 93dc4a2c37e3..1c87d41a235f 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -290,9 +290,7 @@ def _get_modules_to_save(self, pipe, has_denoiser=False): return modules_to_save - def check_if_adapters_added_correctly( - self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default" - ): + def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): if text_lora_config is not None: if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) @@ -344,7 +342,7 @@ def test_simple_inference_with_text_lora(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( @@ -427,7 +425,7 @@ def test_low_cpu_mem_usage_with_loading(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -483,7 +481,7 @@ def test_simple_inference_with_text_lora_and_scale(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( @@ -521,7 +519,7 @@ def test_simple_inference_with_text_lora_fused(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe.fuse_lora() # Fusing should still keep the LoRA layers @@ -553,7 +551,7 @@ def test_simple_inference_with_text_lora_unloaded(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -588,7 +586,7 @@ def test_simple_inference_with_text_lora_save_load(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -639,7 +637,7 @@ def test_simple_inference_with_partial_text_lora(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) state_dict = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -690,7 +688,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdirname: @@ -733,7 +731,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -774,7 +772,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( @@ -818,7 +816,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) @@ -856,7 +854,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -892,7 +890,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused( pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") @@ -1009,7 +1007,7 @@ def test_wrong_adapter_name_raises_error(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) @@ -1031,7 +1029,7 @@ def test_multiple_wrong_adapter_name_raises_error(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) @@ -1758,7 +1756,7 @@ def test_simple_inference_with_dora(self): output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_dora_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1849,7 +1847,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) @@ -1936,7 +1934,7 @@ def test_set_adapters_match_attention_kwargs(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) lora_scale = 0.5 attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} @@ -2118,7 +2116,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): pipe = pipe.to(torch_device, dtype=compute_dtype) pipe.set_progress_bar_config(disable=None) - pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) if storage_dtype is not None: denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) @@ -2236,7 +2234,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): ) pipe = self.pipeline_class(**components) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) @@ -2289,7 +2287,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -2308,6 +2306,29 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." ) + def test_lora_unload_add_adapter(self, lora_alpha): + """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components( + scheduler_cls, lora_alpha=lora_alpha + ) + pipe = self.pipeline_class(**components).to(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + # unload and then add. + pipe.unload_lora_weights() + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + + output_lora_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(np.allclose(output_lora, output_lora_2, atol=1e-3, rtol=1e-3), "Lora outputs should match.") + def test_inference_load_delete_load_adapters(self): "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." for scheduler_cls in self.scheduler_classes: From b83a7ce36e20a9fb8c03dd83acfe64155f34bb9f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Jun 2025 20:40:10 +0530 Subject: [PATCH 2/3] fix --- tests/lora/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 1c87d41a235f..91d3c5f2c32e 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2306,12 +2306,10 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." ) - def test_lora_unload_add_adapter(self, lora_alpha): + def test_lora_unload_add_adapter(self): """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" scheduler_cls = self.scheduler_classes[0] - components, text_lora_config, denoiser_lora_config = self.get_dummy_components( - scheduler_cls, lora_alpha=lora_alpha - ) + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components).to(torch_device) _, _, inputs = self.get_dummy_inputs(with_generator=False) From c228e8ee6a0b27e80006deb56879235a018efca0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Jun 2025 22:20:40 +0530 Subject: [PATCH 3/3] update --- tests/lora/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 91d3c5f2c32e..7c9fefaa539a 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2316,16 +2316,14 @@ def test_lora_unload_add_adapter(self): pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] # unload and then add. pipe.unload_lora_weights() pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) - - output_lora_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(output_lora, output_lora_2, atol=1e-3, rtol=1e-3), "Lora outputs should match.") + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] def test_inference_load_delete_load_adapters(self): "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."