Skip to content

Commit f20c022

Browse files
committed
Make the set_adapters() method more robust.
1 parent b52684c commit f20c022

File tree

3 files changed

+129
-0
lines changed

3 files changed

+129
-0
lines changed

better_set_adapters.patch

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
2+
index 89bb498a3..8ecd4d459 100644
3+
--- a/src/diffusers/loaders/lora_base.py
4+
+++ b/src/diffusers/loaders/lora_base.py
5+
@@ -532,6 +532,11 @@ class LoraBaseMixin:
6+
)
7+
8+
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
9+
+ current_adapter_names = {adapter for _, adapter_list in list_adapters.items() for adapter in adapter_list}
10+
+ for input_adapter_name in adapter_names:
11+
+ if input_adapter_name not in current_adapter_names:
12+
+ raise ValueError(f"Adapter name {input_adapter_name} not in the list of present adapters: {current_adapter_names}.")
13+
+
14+
all_adapters = {
15+
adapter for adapters in list_adapters.values() for adapter in adapters
16+
} # eg ["adapter1", "adapter2"]
17+
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
18+
index 939b749c2..163260709 100644
19+
--- a/tests/lora/utils.py
20+
+++ b/tests/lora/utils.py
21+
@@ -929,12 +929,16 @@ class PeftLoraLoaderMixinTests:
22+
23+
pipe.set_adapters("adapter-1")
24+
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
25+
+ self.assertFalse(np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.")
26+
+
27+
28+
pipe.set_adapters("adapter-2")
29+
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
30+
+ self.assertFalse(np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.")
31+
32+
pipe.set_adapters(["adapter-1", "adapter-2"])
33+
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
34+
+ self.assertFalse(np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.")
35+
36+
# Fuse and unfuse should lead to the same results
37+
self.assertFalse(
38+
@@ -960,6 +964,40 @@ class PeftLoraLoaderMixinTests:
39+
"output with no lora and output with lora disabled should give same results",
40+
)
41+
42+
+ def test_wrong_adapter_name_raises_error(self):
43+
+ scheduler_cls = self.scheduler_classes[0]
44+
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
45+
+ pipe = self.pipeline_class(**components)
46+
+ pipe = pipe.to(torch_device)
47+
+ pipe.set_progress_bar_config(disable=None)
48+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
49+
+
50+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
51+
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
52+
+ self.assertTrue(
53+
+ check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
54+
+ )
55+
+
56+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
57+
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
58+
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
59+
+
60+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
61+
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
62+
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
63+
+ self.assertTrue(
64+
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
65+
+ )
66+
+
67+
+ with self.assertRaises(ValueError) as err_context:
68+
+ pipe.set_adapters("test")
69+
+
70+
+ self.assertTrue("not in the list of present adapters" in str(err_context.exception))
71+
+
72+
+ # test this works.
73+
+ pipe.set_adapters("adapter-1")
74+
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
75+
+
76+
def test_simple_inference_with_text_denoiser_block_scale(self):
77+
"""
78+
Tests a simple inference with lora attached to text encoder and unet, attaches

src/diffusers/loaders/lora_base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,13 @@ def set_adapters(
532532
)
533533

534534
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
535+
current_adapter_names = {adapter for _, adapter_list in list_adapters.items() for adapter in adapter_list}
536+
for input_adapter_name in adapter_names:
537+
if input_adapter_name not in current_adapter_names:
538+
raise ValueError(
539+
f"Adapter name {input_adapter_name} not in the list of present adapters: {current_adapter_names}."
540+
)
541+
535542
all_adapters = {
536543
adapter for adapters in list_adapters.values() for adapter in adapters
537544
} # eg ["adapter1", "adapter2"]

tests/lora/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,12 +929,24 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
929929

930930
pipe.set_adapters("adapter-1")
931931
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
932+
self.assertFalse(
933+
np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
934+
"Adapter outputs should be different.",
935+
)
932936

933937
pipe.set_adapters("adapter-2")
934938
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
939+
self.assertFalse(
940+
np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
941+
"Adapter outputs should be different.",
942+
)
935943

936944
pipe.set_adapters(["adapter-1", "adapter-2"])
937945
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
946+
self.assertFalse(
947+
np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
948+
"Adapter outputs should be different.",
949+
)
938950

939951
# Fuse and unfuse should lead to the same results
940952
self.assertFalse(
@@ -960,6 +972,38 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
960972
"output with no lora and output with lora disabled should give same results",
961973
)
962974

975+
def test_wrong_adapter_name_raises_error(self):
976+
scheduler_cls = self.scheduler_classes[0]
977+
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
978+
pipe = self.pipeline_class(**components)
979+
pipe = pipe.to(torch_device)
980+
pipe.set_progress_bar_config(disable=None)
981+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
982+
983+
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
984+
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
985+
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
986+
987+
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
988+
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
989+
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
990+
991+
if self.has_two_text_encoders or self.has_three_text_encoders:
992+
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
993+
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
994+
self.assertTrue(
995+
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
996+
)
997+
998+
with self.assertRaises(ValueError) as err_context:
999+
pipe.set_adapters("test")
1000+
1001+
self.assertTrue("not in the list of present adapters" in str(err_context.exception))
1002+
1003+
# test this works.
1004+
pipe.set_adapters("adapter-1")
1005+
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
1006+
9631007
def test_simple_inference_with_text_denoiser_block_scale(self):
9641008
"""
9651009
Tests a simple inference with lora attached to text encoder and unet, attaches

0 commit comments

Comments (0)