From 1106c884f0aaf4012e891639e8986ed3722a2930 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 27 Dec 2024 18:56:24 +0530 Subject: [PATCH 1/3] fix: lora unloading when using expanded Flux LoRAs. --- src/diffusers/loaders/lora_pipeline.py | 4 +- tests/lora/test_lora_layers_flux.py | 61 +++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 351295e938ff..fe44b1de6533 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2278,7 +2278,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) # We override this here account for `_transformer_norm_layers`. - def unload_lora_weights(self): + def unload_lora_weights(self, reset_to_overwrriten_params=False): super().unload_lora_weights() transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer @@ -2286,7 +2286,7 @@ def unload_lora_weights(self): transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) transformer._transformer_norm_layers = None - if getattr(transformer, "_overwritten_params", None) is not None: + if reset_to_overwrriten_params and getattr(transformer, "_overwritten_params", None) is not None: overwritten_params = transformer._overwritten_params module_names = set() diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 0861160de6aa..49f277ba6887 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -606,7 +606,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self): self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - control_pipe.unload_lora_weights() + control_pipe.unload_lora_weights(reset_to_overwrriten_params=True) self.assertTrue( control_pipe.transformer.config.in_channels == num_channels_without_control, f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", @@ -624,6 +624,65 @@ def test_lora_unload_with_parameter_expanded_shapes(self): self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) self.assertTrue(pipe.transformer.config.in_channels == in_features) + def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + self.assertTrue( + transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + ) + + # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. 
+ components["transformer"] = transformer + pipe = FluxPipeline(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + control_image = inputs.pop("control_image") + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + control_pipe = self.pipeline_class(**components) + out_features, in_features = control_pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + control_pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + inputs["control_image"] = control_image + lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) + self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) + self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + + control_pipe.unload_lora_weights(reset_to_overwrriten_params=False) + self.assertTrue( + control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + ) + no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) + self.assertTrue(pipe.transformer.config.in_channels == in_features * 2) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass From 71851388f2b0ea07436131bb584115d67a8d2fde Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 30 Dec 2024 13:35:20 +0530 Subject: [PATCH 2/3] fix argument name. Co-authored-by: a-r-r-o-w --- src/diffusers/loaders/lora_pipeline.py | 4 ++-- tests/lora/test_lora_layers_flux.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index fe44b1de6533..e814b76d1ad6 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2278,7 +2278,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) # We override this here account for `_transformer_norm_layers`. 
-    def unload_lora_weights(self, reset_to_overwrriten_params=False):
+    def unload_lora_weights(self, reset_to_overwritten_params=False):
         super().unload_lora_weights()
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
@@ -2286,7 +2286,7 @@ def unload_lora_weights(self, reset_to_overwrriten_params=False):
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None
 
-        if reset_to_overwrriten_params and getattr(transformer, "_overwritten_params", None) is not None:
+        if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
             overwritten_params = transformer._overwritten_params
 
             module_names = set()
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 49f277ba6887..4478584f103a 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -606,7 +606,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-        control_pipe.unload_lora_weights(reset_to_overwrriten_params=True)
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
         self.assertTrue(
             control_pipe.transformer.config.in_channels == num_channels_without_control,
             f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
@@ -672,7 +672,7 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-        control_pipe.unload_lora_weights(reset_to_overwrriten_params=False)
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
         self.assertTrue(
             control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
             f"Expected {2 * num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",

From 4b370d2ba73ac40e0540aa3b143fcd4b7b149cbd Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 30 Dec 2024 13:53:03 +0530
Subject: [PATCH 3/3] docs.

---
 docs/source/en/api/pipelines/flux.md   |  4 ++++
 src/diffusers/loaders/lora_pipeline.py | 18 +++++++++++++++++-
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md
index 080442efb0d1..ad89b267a2b7 100644
--- a/docs/source/en/api/pipelines/flux.md
+++ b/docs/source/en/api/pipelines/flux.md
@@ -305,6 +305,10 @@ image = control_pipe(
 image.save("output.png")
 ```
 
+## Note about `unload_lora_weights()` when using Flux LoRAs
+
+When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resulting pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
+
 ## Running FP16 inference
 
 Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index e814b76d1ad6..5d1bcffd3422 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2277,8 +2277,24 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
 
         super().unfuse_lora(components=components)
 
-    # We override this here account for `_transformer_norm_layers`.
+    # We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
     def unload_lora_weights(self, reset_to_overwritten_params=False):
+        """
+        Unloads the LoRA parameters.
+
+        Args:
+            reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
+                to their original params. Refer to the [Flux
+                documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+        >>> pipeline.unload_lora_weights()
+        >>> ...
+        ```
+        """
         super().unload_lora_weights()
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
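For context, a minimal usage sketch of the `reset_to_overwritten_params` argument this series adds, following the Flux Control LoRA workflow from the docs. The checkpoint and LoRA repository IDs below come from the Flux documentation and are illustrative assumptions, not part of this diff:

```python
import torch

from diffusers import FluxControlPipeline

# Illustrative model/LoRA IDs (per the Flux docs); adjust to your setup.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Loading a Control LoRA expands `transformer.x_embedder` in place:
# `config.in_channels` doubles so the packed control latents can be accepted.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

# Default (`reset_to_overwritten_params=False`): the LoRA layers are removed,
# but the expanded `x_embedder` and `config.in_channels` are kept.
pipe.unload_lora_weights()

# With the new flag, the transformer is also restored to its pre-expansion
# parameters, so it can be reused, e.g. with `DiffusionPipeline.from_pipe`.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
pipe.unload_lora_weights(reset_to_overwritten_params=True)
```

This mirrors what the two tests assert: without the reset, `config.in_channels` stays at twice its original value; with it, the transformer returns to its original shape.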