diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 1445394b8784..01040b06927b 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2337,12 +2337,19 @@ def _maybe_expand_transformer_param_shape_or_error_(
                         f"this please open an issue at https://github.com/huggingface/diffusers/issues."
                     )
 
-                logger.debug(
+                debug_message = (
                     f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
                     f"checkpoint contains higher number of features than expected. The number of input_features will be "
-                    f"expanded from {module_in_features} to {in_features}, and the number of output features will be "
-                    f"expanded from {module_out_features} to {out_features}."
+                    f"expanded from {module_in_features} to {in_features}"
                 )
+                if module_out_features != out_features:
+                    debug_message += (
+                        ", and the number of output features will be "
+                        f"expanded from {module_out_features} to {out_features}."
+                    )
+                else:
+                    debug_message += "."
+                logger.debug(debug_message)
                 has_param_with_shape_update = True
 
                 parent_module_name, _, current_module_name = name.rpartition(".")
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 32df644b758d..3851ff32ddfa 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -205,6 +205,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 weights.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft.tuners.tuners_utils import BaseTunerLayer
 
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
@@ -316,8 +317,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
-            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+            # To handle scenarios where we cannot successfully set the state dict. If it's unsuccessful,
+            # we should also delete the `peft_config` associated with the `adapter_name`.
+            try:
+                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+            except RuntimeError as e:
+                for module in self.modules():
+                    if isinstance(module, BaseTunerLayer):
+                        active_adapters = module.active_adapters
+                        for active_adapter in active_adapters:
+                            if adapter_name in active_adapter:
+                                module.delete_adapter(adapter_name)
+
+                self.peft_config.pop(adapter_name)
+                logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
+                raise
 
             warn_msg = ""
             if incompatible_keys is not None:
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 8142085f981c..b28fdde91574 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -430,6 +430,122 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
 
+    def test_lora_expanding_shape_with_normal_lora_raises_error(self):
+        # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
+        # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
+        # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        components["transformer"] = transformer
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+        self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
+        normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+        }
+
+        # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
+        # input features before expansion. This should raise an error about the weight shapes being incompatible.
+        self.assertRaisesRegex(
+            RuntimeError,
+            "size mismatch for x_embedder.lora_A.adapter-2.weight",
+            pipe.load_lora_weights,
+            lora_state_dict,
+            "adapter-2",
+        )
+        # We should have `adapter-1` as the only adapter.
+        self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
+
+        # Check that the output is the same after the lora loading error.
+        lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
+
+        # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
+        # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
+        # original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
+        # weight is compatible with the current model inadequate. This should be addressed when attempting support for
+        # https://github.com/huggingface/diffusers/issues/10180 (TODO)
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        components["transformer"] = transformer
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+        }
+
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features)
+        self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+        }
+
+        # We should check for input shapes being incompatible here. But because the above-mentioned issue is
+        # not a supported use case, and because of the PEFT renaming, we will currently have a shape
+        # mismatch error.
+        self.assertRaisesRegex(
+            RuntimeError,
+            "size mismatch for x_embedder.lora_A.adapter-2.weight",
+            pipe.load_lora_weights,
+            lora_state_dict,
+            "adapter-2",
+        )
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
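
Note for reviewers: a minimal, illustrative sketch (not part of the patch) of the user-facing behavior the peft.py change targets. The pipeline class is real, but the LoRA repo ids below are hypothetical placeholders, and the exact RuntimeError depends on the checkpoints involved.

import torch
from diffusers import FluxPipeline

# Load a base Flux pipeline (assumes access to the model weights).
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# A first LoRA (e.g. a Control LoRA) loads fine and may expand x_embedder's input features.
pipe.load_lora_weights("some-org/control-lora", adapter_name="adapter-1")  # hypothetical repo id

try:
    # A second LoRA whose x_embedder shapes no longer match the expanded base layer.
    pipe.load_lora_weights("some-org/normal-lora", adapter_name="adapter-2")  # hypothetical repo id
except RuntimeError:
    # With this patch, the partially injected "adapter-2" is deleted from the tuner layers and
    # popped from `peft_config`, so the pipeline keeps working with "adapter-1" only.
    assert pipe.get_active_adapters() == ["adapter-1"]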