diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 4644ee81d48f..32d7c773d2b0 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -465,7 +465,7 @@ class LoraBaseMixin:
     """Utility class for handling LoRAs."""

     _lora_loadable_modules = []
-    num_fused_loras = 0
+    _merged_adapters = set()

     def load_lora_weights(self, **kwargs):
         raise NotImplementedError("`load_lora_weights()` is not implemented.")
@@ -592,6 +592,9 @@ def fuse_lora(
         if len(components) == 0:
             raise ValueError("`components` cannot be an empty list.")

+        # We need to collect the merged adapter names from the modules, since `adapter_names` can be None
+        # and so cannot be used directly in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
+        merged_adapter_names = set()
         for fuse_component in components:
             if fuse_component not in self._lora_loadable_modules:
                 raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -601,13 +604,19 @@ def fuse_lora(
                 # check if diffusers model
                 if issubclass(model.__class__, ModelMixin):
                     model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
+                    for module in model.modules():
+                        if isinstance(module, BaseTunerLayer):
+                            merged_adapter_names.update(set(module.merged_adapters))
                 # handle transformers models.
                 if issubclass(model.__class__, PreTrainedModel):
                     fuse_text_encoder_lora(
                         model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
                     )
+                    for module in model.modules():
+                        if isinstance(module, BaseTunerLayer):
+                            merged_adapter_names.update(set(module.merged_adapters))

-        self.num_fused_loras += 1
+        self._merged_adapters = self._merged_adapters | merged_adapter_names

     def unfuse_lora(self, components: List[str] = [], **kwargs):
         r"""
@@ -661,9 +670,18 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
             if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
                 for module in model.modules():
                     if isinstance(module, BaseTunerLayer):
+                        for adapter in set(module.merged_adapters):
+                            if adapter and adapter in self._merged_adapters:
+                                self._merged_adapters = self._merged_adapters - {adapter}
                         module.unmerge()

-        self.num_fused_loras -= 1
+    @property
+    def num_fused_loras(self):
+        return len(self._merged_adapters)
+
+    @property
+    def fused_loras(self):
+        return self._merged_adapters

     def set_adapters(
         self,
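With this change, `fuse_lora()` records the names of the adapters it actually merged (read back from each `BaseTunerLayer.merged_adapters`), and `num_fused_loras` / `fused_loras` are derived from that set instead of an incremented counter. A minimal usage sketch of the resulting behavior, assuming a LoRA-capable pipeline with two hypothetical adapters `"adapter-1"` and `"adapter-2"` already loaded via `load_lora_weights()`:

```python
# Sketch only: `pipe` is a hypothetical LoRA-capable pipeline with the
# adapters "adapter-1" and "adapter-2" already loaded into its UNet.
pipe.fuse_lora(components=["unet"], adapter_names=["adapter-1"])
assert pipe.num_fused_loras == 1          # length of the tracked name set
assert pipe.fused_loras == {"adapter-1"}  # the names, not just a count

pipe.fuse_lora(components=["unet"], adapter_names=["adapter-2"])
assert pipe.num_fused_loras == 2

pipe.unfuse_lora(components=["unet"])  # unmerges and drops the tracked names
assert pipe.num_fused_loras == 0
```

Because the state is a set of names rather than a call counter, fusing the same adapter twice cannot inflate the count, and unfusing only some components removes exactly the adapters that were merged there.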
diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py
index dc2695452c2f..26dcdb1f4f41 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -124,6 +124,9 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
     def test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

+    def test_lora_scale_kwargs_match_fusion(self):
+        super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
+
     @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py
index 0a31f214a38c..22d3ecbde8f4 100644
--- a/tests/lora/test_lora_layers_sdxl.py
+++ b/tests/lora/test_lora_layers_sdxl.py
@@ -117,6 +117,40 @@ def tearDown(self):
     def test_multiple_wrong_adapter_name_raises_error(self):
         super().test_multiple_wrong_adapter_name_raises_error()

+    def test_simple_inference_with_text_denoiser_lora_unfused(self):
+        if torch.cuda.is_available():
+            expected_atol = 9e-2
+            expected_rtol = 9e-2
+        else:
+            expected_atol = 1e-3
+            expected_rtol = 1e-3
+
+        super().test_simple_inference_with_text_denoiser_lora_unfused(
+            expected_atol=expected_atol, expected_rtol=expected_rtol
+        )
+
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+        if torch.cuda.is_available():
+            expected_atol = 9e-2
+            expected_rtol = 9e-2
+        else:
+            expected_atol = 1e-3
+            expected_rtol = 1e-3
+
+        super().test_simple_inference_with_text_lora_denoiser_fused_multi(
+            expected_atol=expected_atol, expected_rtol=expected_rtol
+        )
+
+    def test_lora_scale_kwargs_match_fusion(self):
+        if torch.cuda.is_available():
+            expected_atol = 9e-2
+            expected_rtol = 9e-2
+        else:
+            expected_atol = 1e-3
+            expected_rtol = 1e-3
+
+        super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)
+


 @slow
 @nightly
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index cc760ea84cd0..a118c150644a 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -80,6 +80,18 @@ def initialize_dummy_state_dict(state_dict):
 POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


+def determine_attention_kwargs_name(pipeline_class):
+    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
+
+    # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
+    for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
+        if possible_attention_kwargs in call_signature_keys:
+            attention_kwargs_name = possible_attention_kwargs
+            break
+    assert attention_kwargs_name is not None
+    return attention_kwargs_name
+
+
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     pipeline_class = None
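The extracted `determine_attention_kwargs_name()` helper resolves which kwargs argument a pipeline's `__call__` accepts for passing the LoRA scale. A quick sketch of its behavior; the pipeline class is real, but the asserted result is an assumption based on its current signature:

```python
from diffusers import StableDiffusionXLPipeline

# SDXL-style pipelines expose `cross_attention_kwargs` in `__call__`,
# so the helper should resolve to that name.
name = determine_attention_kwargs_name(StableDiffusionXLPipeline)
assert name == "cross_attention_kwargs"

# The tests then use the resolved name to pass a LoRA scale at call time:
# pipe(**inputs, **{name: {"scale": 0.8}})
```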
@@ -442,14 +454,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         Tests a simple inference with lora attached on the text encoder + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-
-        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -740,12 +745,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         Tests a simple inference with lora attached on the text encoder + Unet + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -878,9 +878,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
         pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
         output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
         output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         # unloading should remove the LoRA layers
@@ -1608,26 +1610,21 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             self.assertTrue(
                 check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
             )
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-
-        # Attach a second adapter
-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-
-        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
-
         self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+        denoiser.add_adapter(denoiser_lora_config, "adapter-2")

         if self.has_two_text_encoders or self.has_three_text_encoders:
             lora_loadable_components = self.pipeline_class._lora_loadable_modules
             if "text_encoder_2" in lora_loadable_components:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                 )
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

         # set them to multi-adapter inference mode
         pipe.set_adapters(["adapter-1", "adapter-2"])
@@ -1637,6 +1634,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
         outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

         # Fusing should still keep the LoRA layers so output should remain the same
         outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1647,9 +1645,23 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
         )

         pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
+
+        if self.has_two_text_encoders or self.has_three_text_encoders:
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+                )
+
         pipe.fuse_lora(
             components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
         )
+        self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

         # Fusing should still keep the LoRA layers
         output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1657,6 +1669,63 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
             "Fused lora should not change the output",
         )
+        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+    def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+
+        for lora_scale in [1.0, 0.8]:
+            for scheduler_cls in self.scheduler_classes:
+                components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+                pipe = self.pipeline_class(**components)
+                pipe = pipe.to(torch_device)
+                pipe.set_progress_bar_config(disable=None)
+                _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+                output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertTrue(output_no_lora.shape == self.output_shape)
+
+                if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                    )
+
+                denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+                denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+                self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+                if self.has_two_text_encoders or self.has_three_text_encoders:
+                    lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                    if "text_encoder_2" in lora_loadable_components:
+                        pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                        self.assertTrue(
+                            check_if_lora_correctly_set(pipe.text_encoder_2),
+                            "Lora not correctly set in text encoder 2",
+                        )
+
+                pipe.set_adapters(["adapter-1"])
+                attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+                outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
+                pipe.fuse_lora(
+                    components=self.pipeline_class._lora_loadable_modules,
+                    adapter_names=["adapter-1"],
+                    lora_scale=lora_scale,
+                )
+                self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+                outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+                self.assertTrue(
+                    np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
+                    "Fused lora should not change the output",
+                )
+                self.assertFalse(
+                    np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
+                    "LoRA should change the output",
+                )

     @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
@@ -1838,12 +1907,7 @@ def test_logs_info_when_no_lora_keys_found(self):

     def test_set_adapters_match_attention_kwargs(self):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
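The new `test_lora_scale_kwargs_match_fusion` pins down the invariant that scaling a LoRA at call time should match baking the same scale into the weights at fusion time. A condensed sketch of the two routes being compared, assuming a hypothetical `pipe` and `inputs` with a single loaded adapter `"adapter-1"`:

```python
import numpy as np
import torch

lora_scale = 0.8
kwargs_name = determine_attention_kwargs_name(type(pipe))

# Route 1: leave the LoRA unfused and apply the scale at call time.
out_runtime = pipe(**inputs, generator=torch.manual_seed(0), **{kwargs_name: {"scale": lora_scale}})[0]

# Route 2: fold the same scale into the base weights via fusion.
pipe.fuse_lora(components=["unet"], adapter_names=["adapter-1"], lora_scale=lora_scale)
out_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

# Both routes should agree up to numerical tolerance (looser on CUDA,
# which is why the SDXL overrides above widen atol/rtol on GPU).
assert np.allclose(out_runtime, out_fused, atol=1e-3, rtol=1e-3)
```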