Skip to content

Commit 4eef79e

Browse files
sayakpaula-r-r-o-w
andcommitted
tests
Co-authored-by: a-r-r-o-w <[email protected]>
1 parent 5ef79f3 commit 4eef79e

File tree

2 files changed

+109
-24
lines changed

2 files changed

+109
-24
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2407,7 +2407,7 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
24072407
expanded_module_names.add(k)
24082408
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
24092409
raise NotImplementedError(
2410-
"We currently don't support loading LoRAs for this use case. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
2410+
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
24112411
)
24122412

24132413
if expanded_module_names:

tests/lora/test_lora_layers_flux.py

Lines changed: 108 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -430,10 +430,10 @@ def test_correct_lora_configs_with_different_ranks(self):
430430
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
431431
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
432432

433-
def test_lora_expanding_shape_with_normal_lora_raises_error(self):
434-
# TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
435-
# another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
436-
# When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
433+
def test_lora_expanding_shape_with_normal_lora(self):
434+
# This test checks if it works when a lora with expanded shapes (like control loras) but
435+
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
436+
# tested with it.
437437
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
438438

439439
# Change the transformer config to mimic a real use case.
@@ -478,21 +478,16 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
478478
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
479479
}
480480

481-
# The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
482-
# input features before expansion. This should raise an error about the weight shapes being incompatible.
483-
self.assertRaisesRegex(
484-
RuntimeError,
485-
"size mismatch for x_embedder.lora_A.adapter-2.weight",
486-
pipe.load_lora_weights,
487-
lora_state_dict,
488-
"adapter-2",
489-
)
490-
# We should have `adapter-1` as the only adapter.
491-
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
481+
with CaptureLogger(logger) as cap_logger:
482+
pipe.load_lora_weights(lora_state_dict, "adapter-2")
483+
484+
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
485+
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
486+
487+
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
492488

493-
# Check if the output is the same after lora loading error
494-
lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
495-
self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
489+
self.assertTrue("Found some LoRA modules for which the weights were zero-padded" in cap_logger.out)
490+
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
496491

497492
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
498493
# This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
@@ -524,8 +519,8 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
524519

525520
with CaptureLogger(logger) as cap_logger:
526521
pipe.load_lora_weights(lora_state_dict, "adapter-1")
527-
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
528522

523+
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
529524
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
530525
self.assertTrue(pipe.transformer.config.in_channels == in_features)
531526
self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
@@ -535,17 +530,107 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
535530
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
536531
}
537532

538-
# We should check for input shapes being incompatible here. But because above mentioned issue is
539-
# not a supported use case, and because of the PEFT renaming, we will currently have a shape
540-
# mismatch error.
533+
# We should check for input shapes being incompatible here.
541534
self.assertRaisesRegex(
542535
RuntimeError,
543-
"size mismatch for x_embedder.lora_A.adapter-2.weight",
536+
"x_embedder.lora_A.weight",
544537
pipe.load_lora_weights,
545538
lora_state_dict,
546539
"adapter-2",
547540
)
548541

542+
def test_fuse_expanded_lora_with_regular_lora(self):
543+
# This test checks if it works when a lora with expanded shapes (like control loras) but
544+
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
545+
# tested with it.
546+
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
547+
548+
# Change the transformer config to mimic a real use case.
549+
num_channels_without_control = 4
550+
transformer = FluxTransformer2DModel.from_config(
551+
components["transformer"].config, in_channels=num_channels_without_control
552+
).to(torch_device)
553+
components["transformer"] = transformer
554+
555+
pipe = self.pipeline_class(**components)
556+
pipe = pipe.to(torch_device)
557+
pipe.set_progress_bar_config(disable=None)
558+
559+
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
560+
logger.setLevel(logging.DEBUG)
561+
562+
out_features, in_features = pipe.transformer.x_embedder.weight.shape
563+
rank = 4
564+
565+
shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
566+
shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
567+
lora_state_dict = {
568+
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
569+
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
570+
}
571+
pipe.load_lora_weights(lora_state_dict, "adapter-1")
572+
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
573+
574+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
575+
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
576+
577+
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
578+
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
579+
lora_state_dict = {
580+
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
581+
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
582+
}
583+
584+
pipe.load_lora_weights(lora_state_dict, "adapter-2")
585+
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
586+
587+
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
588+
589+
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
590+
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
591+
592+
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
593+
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
594+
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
595+
596+
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
597+
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
598+
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
599+
600+
def test_load_regular_lora(self):
601+
# This test checks if a regular lora (think of one trained Flux.1 Dev for example) can be loaded
602+
# into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
603+
# transformers include Flux Fill, Flux Control, etc.
604+
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
605+
pipe = self.pipeline_class(**components)
606+
pipe = pipe.to(torch_device)
607+
pipe.set_progress_bar_config(disable=None)
608+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
609+
610+
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
611+
612+
out_features, in_features = pipe.transformer.x_embedder.weight.shape
613+
rank = 4
614+
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
615+
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
616+
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
617+
lora_state_dict = {
618+
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
619+
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
620+
}
621+
622+
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
623+
logger.setLevel(logging.INFO)
624+
with CaptureLogger(logger) as cap_logger:
625+
pipe.load_lora_weights(lora_state_dict, "adapter-1")
626+
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
627+
628+
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
629+
630+
self.assertTrue("Found some LoRA modules for which the weights were zero-padded" in cap_logger.out)
631+
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
632+
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
633+
549634
@unittest.skip("Not supported in Flux.")
550635
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
551636
pass

0 commit comments

Comments
 (0)