Skip to content

Commit f42ddf4

Browse files
committed
updates
1 parent 8de4eb9 commit f42ddf4

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

tests/lora/utils.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -658,12 +658,12 @@ def test_simple_inference_with_partial_text_lora(self):
658658
"Removing adapters should change the output",
659659
)
660660

661-
def test_simple_inference_save_pretrained(self):
661+
def test_simple_inference_save_pretrained_with_text_lora(self):
662662
"""
663663
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
664664
"""
665665
for scheduler_cls in self.scheduler_classes:
666-
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
666+
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
667667
pipe = self.pipeline_class(**components)
668668
pipe = pipe.to(torch_device)
669669
pipe.set_progress_bar_config(disable=None)
@@ -672,7 +672,7 @@ def test_simple_inference_save_pretrained(self):
672672
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
673673
self.assertTrue(output_no_lora.shape == self.output_shape)
674674

675-
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
675+
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
676676
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
677677

678678
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -681,14 +681,11 @@ def test_simple_inference_save_pretrained(self):
681681
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
682682
pipe_from_pretrained.to(torch_device)
683683

684-
self.assertTrue(
685-
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
686-
"Lora not correctly set in text encoder",
687-
)
688-
689-
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
690-
denoiser.add_adapter(denoiser_lora_config)
691-
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
684+
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
685+
self.assertTrue(
686+
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
687+
"Lora not correctly set in text encoder",
688+
)
692689

693690
if self.has_two_text_encoders or self.has_three_text_encoders:
694691
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
@@ -988,46 +985,53 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
988985
)
989986

990987
def test_wrong_adapter_name_raises_error(self):
988+
adapter_name = "adapter-1"
989+
991990
scheduler_cls = self.scheduler_classes[0]
992991
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
993992
pipe = self.pipeline_class(**components)
994993
pipe = pipe.to(torch_device)
995994
pipe.set_progress_bar_config(disable=None)
996995
_, _, inputs = self.get_dummy_inputs(with_generator=False)
997996

998-
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
997+
pipe, _ = self.check_if_adapters_added_correctly(
998+
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
999+
)
9991000

10001001
with self.assertRaises(ValueError) as err_context:
10011002
pipe.set_adapters("test")
10021003

10031004
self.assertTrue("not in the list of present adapters" in str(err_context.exception))
10041005

10051006
# test this works.
1006-
pipe.set_adapters("adapter-1")
1007+
pipe.set_adapters(adapter_name)
10071008
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
10081009

10091010
def test_multiple_wrong_adapter_name_raises_error(self):
1011+
adapter_name = "adapter-1"
10101012
scheduler_cls = self.scheduler_classes[0]
10111013
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
10121014
pipe = self.pipeline_class(**components)
10131015
pipe = pipe.to(torch_device)
10141016
pipe.set_progress_bar_config(disable=None)
10151017
_, _, inputs = self.get_dummy_inputs(with_generator=False)
10161018

1017-
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
1019+
pipe, _ = self.check_if_adapters_added_correctly(
1020+
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
1021+
)
10181022

10191023
scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
10201024
logger = logging.get_logger("diffusers.loaders.lora_base")
10211025
logger.setLevel(30)
10221026
with CaptureLogger(logger) as cap_logger:
1023-
pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components)
1027+
pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
10241028

10251029
wrong_components = sorted(set(scale_with_wrong_components.keys()))
10261030
msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
10271031
self.assertTrue(msg in str(cap_logger.out))
10281032

10291033
# test this works.
1030-
pipe.set_adapters("adapter-1")
1034+
pipe.set_adapters(adapter_name)
10311035
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
10321036

10331037
def test_simple_inference_with_text_denoiser_block_scale(self):

0 commit comments

Comments
 (0)