
Commit 1106c88

fix: lora unloading when using expanded Flux LoRAs.
1 parent 83da817 commit 1106c88

File tree: 2 files changed (+62, -3 lines)

    src/diffusers/loaders/lora_pipeline.py
    tests/lora/test_lora_layers_flux.py

src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -2278,15 +2278,15 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
         super().unfuse_lora(components=components)

     # We override this here account for `_transformer_norm_layers`.
-    def unload_lora_weights(self):
+    def unload_lora_weights(self, reset_to_overwrriten_params=False):
         super().unload_lora_weights()

         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
         if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None

-        if getattr(transformer, "_overwritten_params", None) is not None:
+        if reset_to_overwrriten_params and getattr(transformer, "_overwritten_params", None) is not None:
             overwritten_params = transformer._overwritten_params
             module_names = set()
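
A minimal usage sketch of the new flag (parameter name spelled as in this commit). The model ids, device, and channel counts below are illustrative assumptions rather than part of the diff; the point is that `unload_lora_weights()` now keeps an expanded input projection by default and only restores the original parameters when explicitly asked to:

import torch
from diffusers import FluxControlPipeline

# Illustrative checkpoints; any Flux base model plus a Control LoRA that
# expands `x_embedder` behaves the same way.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Loading a Flux Control LoRA expands the transformer's `x_embedder`
# (e.g. in_channels 64 -> 128) and records the original parameters in
# `transformer._overwritten_params`.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

# Default (reset_to_overwrriten_params=False): the LoRA layers are removed,
# but the expanded input projection is kept, so another expanded LoRA can be
# loaded without re-expanding the module.
#
# Passing True opts back into the pre-commit behaviour and restores the
# pre-expansion shapes recorded at load time:
pipe.unload_lora_weights(reset_to_overwrriten_params=True)
assert pipe.transformer.config.in_channels == 64  # illustrative value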

tests/lora/test_lora_layers_flux.py

Lines changed: 60 additions & 1 deletion
@@ -606,7 +606,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

-        control_pipe.unload_lora_weights()
+        control_pipe.unload_lora_weights(reset_to_overwrriten_params=True)
         self.assertTrue(
             control_pipe.transformer.config.in_channels == num_channels_without_control,
             f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
@@ -624,6 +624,65 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
         self.assertTrue(pipe.transformer.config.in_channels == in_features)

+    def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+        components["transformer"] = transformer
+        pipe = FluxPipeline(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        control_image = inputs.pop("control_image")
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        control_pipe = self.pipeline_class(**components)
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        inputs["control_image"] = control_image
+        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        control_pipe.unload_lora_weights(reset_to_overwrriten_params=False)
+        self.assertTrue(
+            control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
+        )
+        no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
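
To exercise just the tests touched by this commit, one option is pytest's programmatic entry point. A small sketch, assuming pytest and the diffusers test dependencies are installed; the `-k` expression matches both the updated test and the newly added one:

import sys

import pytest

# Run only the expanded-shape unload tests from tests/lora/test_lora_layers_flux.py.
sys.exit(
    pytest.main(
        [
            "tests/lora/test_lora_layers_flux.py",
            "-k", "lora_unload_with_parameter_expanded_shapes",
            "-v",
        ]
    )
)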
