Skip to content

Commit 343b2d2

Browse files
committed
Merge branch 'main' into improve-lora-warning-msg
2 parents ec44f9a + 1b202c5 commit 343b2d2

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

tests/lora/test_lora_layers_flux.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,72 @@ def test_load_regular_lora(self):
558558
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
559559
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
560560

561+
def test_lora_unload_with_parameter_expanded_shapes(self):
562+
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
563+
564+
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
565+
logger.setLevel(logging.DEBUG)
566+
567+
# Change the transformer config to mimic a real use case.
568+
num_channels_without_control = 4
569+
transformer = FluxTransformer2DModel.from_config(
570+
components["transformer"].config, in_channels=num_channels_without_control
571+
).to(torch_device)
572+
self.assertTrue(
573+
transformer.config.in_channels == num_channels_without_control,
574+
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
575+
)
576+
577+
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
578+
components["transformer"] = transformer
579+
pipe = FluxPipeline(**components)
580+
pipe = pipe.to(torch_device)
581+
pipe.set_progress_bar_config(disable=None)
582+
583+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
584+
control_image = inputs.pop("control_image")
585+
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
586+
587+
control_pipe = self.pipeline_class(**components)
588+
out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
589+
rank = 4
590+
591+
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
592+
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
593+
lora_state_dict = {
594+
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
595+
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
596+
}
597+
with CaptureLogger(logger) as cap_logger:
598+
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
599+
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
600+
601+
inputs["control_image"] = control_image
602+
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
603+
604+
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
605+
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
606+
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
607+
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
608+
609+
control_pipe.unload_lora_weights()
610+
self.assertTrue(
611+
control_pipe.transformer.config.in_channels == num_channels_without_control,
612+
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
613+
)
614+
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
615+
self.assertTrue(
616+
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
617+
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
618+
)
619+
inputs.pop("control_image")
620+
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
621+
622+
self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
623+
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
624+
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
625+
self.assertTrue(pipe.transformer.config.in_channels == in_features)
626+
561627
@unittest.skip("Not supported in Flux.")
562628
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
563629
pass

0 commit comments

Comments
 (0)