From f20c022ebd0b9108c307afecc250a20bf7ea4084 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 26 Sep 2024 11:35:44 +0530
Subject: [PATCH 1/4] make set_adapters() method more robust.

---
 better_set_adapters.patch          | 78 ++++++++++++++++++++++++++++++
 src/diffusers/loaders/lora_base.py |  7 +++
 tests/lora/utils.py                | 44 +++++++++++++++++
 3 files changed, 129 insertions(+)
 create mode 100644 better_set_adapters.patch

diff --git a/better_set_adapters.patch b/better_set_adapters.patch
new file mode 100644
index 000000000000..27a703e8e1af
--- /dev/null
+++ b/better_set_adapters.patch
@@ -0,0 +1,78 @@
+diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
+index 89bb498a3..8ecd4d459 100644
+--- a/src/diffusers/loaders/lora_base.py
++++ b/src/diffusers/loaders/lora_base.py
+@@ -532,6 +532,11 @@ class LoraBaseMixin:
+         )
+ 
+         list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
++        current_adapter_names = {adapter for _, adapter_list in list_adapters.items() for adapter in adapter_list}
++        for input_adapter_name in adapter_names:
++            if input_adapter_name not in current_adapter_names:
++                raise ValueError(f"Adapter name {input_adapter_name} not in the list of present adapters: {current_adapter_names}.")
++
+         all_adapters = {
+             adapter for adapters in list_adapters.values() for adapter in adapters
+         }  # eg ["adapter1", "adapter2"]
+diff --git a/tests/lora/utils.py b/tests/lora/utils.py
+index 939b749c2..163260709 100644
+--- a/tests/lora/utils.py
++++ b/tests/lora/utils.py
+@@ -929,12 +929,16 @@ class PeftLoraLoaderMixinTests:
+ 
+         pipe.set_adapters("adapter-1")
+         output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
++        self.assertFalse(np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.")
++
+ 
+         pipe.set_adapters("adapter-2")
+         output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
++        self.assertFalse(np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.")
+ 
+         pipe.set_adapters(["adapter-1", "adapter-2"])
+         output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
++        self.assertFalse(np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.")
+ 
+         # Fuse and unfuse should lead to the same results
+         self.assertFalse(
+@@ -960,6 +964,40 @@ class PeftLoraLoaderMixinTests:
+             "output with no lora and output with lora disabled should give same results",
+         )
+ 
++    def test_wrong_adapter_name_raises_error(self):
++        scheduler_cls = self.scheduler_classes[0]
++        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
++        pipe = self.pipeline_class(**components)
++        pipe = pipe.to(torch_device)
++        pipe.set_progress_bar_config(disable=None)
++        _, _, inputs = self.get_dummy_inputs(with_generator=False)
++
++        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
++            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
++            self.assertTrue(
++                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
++            )
++
++        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
++        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
++        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
++
++        if self.has_two_text_encoders or self.has_three_text_encoders:
++            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
++                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
++                self.assertTrue(
++                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
++                )
++
++        with self.assertRaises(ValueError) as err_context:
++            pipe.set_adapters("test")
++
++        self.assertTrue("not in the list of present adapters" in str(err_context.exception))
++
++        # test this works.
++        pipe.set_adapters("adapter-1")
++        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
++
+     def test_simple_inference_with_text_denoiser_block_scale(self):
+         """
+         Tests a simple inference with lora attached to text encoder and unet, attaches
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 89bb498a3acd..bee8350d2856 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -532,6 +532,13 @@ def set_adapters(
         )
 
         list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+        current_adapter_names = {adapter for _, adapter_list in list_adapters.items() for adapter in adapter_list}
+        for input_adapter_name in adapter_names:
+            if input_adapter_name not in current_adapter_names:
+                raise ValueError(
+                    f"Adapter name {input_adapter_name} not in the list of present adapters: {current_adapter_names}."
+                )
+
         all_adapters = {
             adapter for adapters in list_adapters.values() for adapter in adapters
         }  # eg ["adapter1", "adapter2"]
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 939b749c286a..43c45daaa322 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -929,12 +929,24 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
 
         pipe.set_adapters("adapter-1")
         output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertFalse(
+            np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
+            "Adapter outputs should be different.",
+        )
 
         pipe.set_adapters("adapter-2")
         output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertFalse(
+            np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
+            "Adapter outputs should be different.",
+        )
 
         pipe.set_adapters(["adapter-1", "adapter-2"])
         output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertFalse(
+            np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+            "Adapter outputs should be different.",
+        )
 
         # Fuse and unfuse should lead to the same results
         self.assertFalse(
@@ -960,6 +972,38 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
             "output with no lora and output with lora disabled should give same results",
         )
 
+    def test_wrong_adapter_name_raises_error(self):
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        if self.has_two_text_encoders or self.has_three_text_encoders:
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+        with self.assertRaises(ValueError) as err_context:
+            pipe.set_adapters("test")
+
+        self.assertTrue("not in the list of present adapters" in str(err_context.exception))
+
+        # test this works.
+        pipe.set_adapters("adapter-1")
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches

From 1fa581c0bdfd2eeb786a89712ebdf8f82d0c80bc Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 26 Sep 2024 11:35:59 +0530
Subject: [PATCH 2/4] remove patch

---
 better_set_adapters.patch | 78 ---------------------------------------
 1 file changed, 78 deletions(-)
 delete mode 100644 better_set_adapters.patch

diff --git a/better_set_adapters.patch b/better_set_adapters.patch
deleted file mode 100644
index 27a703e8e1af..000000000000
--- a/better_set_adapters.patch
+++ /dev/null
@@ -1,78 +0,0 @@
-diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
-index 89bb498a3..8ecd4d459 100644
---- a/src/diffusers/loaders/lora_base.py
-+++ b/src/diffusers/loaders/lora_base.py
-@@ -532,6 +532,11 @@ class LoraBaseMixin:
-         )
- 
-         list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
-+        current_adapter_names = {adapter for _, adapter_list in list_adapters.items() for adapter in adapter_list}
-+        for input_adapter_name in adapter_names:
-+            if input_adapter_name not in current_adapter_names:
-+                raise ValueError(f"Adapter name {input_adapter_name} not in the list of present adapters: {current_adapter_names}.")
-+
-         all_adapters = {
-             adapter for adapters in list_adapters.values() for adapter in adapters
-         }  # eg ["adapter1", "adapter2"]
-diff --git a/tests/lora/utils.py b/tests/lora/utils.py
-index 939b749c2..163260709 100644
---- a/tests/lora/utils.py
-+++ b/tests/lora/utils.py
-@@ -929,12 +929,16 @@ class PeftLoraLoaderMixinTests:
- 
-         pipe.set_adapters("adapter-1")
-         output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
-+        self.assertFalse(np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.")
-+
- 
-         pipe.set_adapters("adapter-2")
-         output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
-+        self.assertFalse(np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.")
- 
-         pipe.set_adapters(["adapter-1", "adapter-2"])
-         output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
-+        self.assertFalse(np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.")
- 
-         # Fuse and unfuse should lead to the same results
-         self.assertFalse(
-@@ -960,6 +964,40 @@ class PeftLoraLoaderMixinTests:
-             "output with no lora and output with lora disabled should give same results",
-         )
- 
-+    def test_wrong_adapter_name_raises_error(self):
-+        scheduler_cls = self.scheduler_classes[0]
-+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-+        pipe = self.pipeline_class(**components)
-+        pipe = pipe.to(torch_device)
-+        pipe.set_progress_bar_config(disable=None)
-+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-+
-+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-+            self.assertTrue(
-+                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-+            )
-+
-+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-+
-+        if self.has_two_text_encoders or self.has_three_text_encoders:
-+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-+                self.assertTrue(
-+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-+                )
-+
-+        with self.assertRaises(ValueError) as err_context:
-+            pipe.set_adapters("test")
-+
-+        self.assertTrue("not in the list of present adapters" in str(err_context.exception))
-+
-+        # test this works.
-+        pipe.set_adapters("adapter-1")
-+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
-+
-     def test_simple_inference_with_text_denoiser_block_scale(self):
-         """
-         Tests a simple inference with lora attached to text encoder and unet, attaches

From b0dc60db142dbc40e1ee5fd147ff46b43e80c641 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 26 Sep 2024 19:55:17 +0530
Subject: [PATCH 3/4] better and concise code.

---
 src/diffusers/loaders/lora_base.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index bee8350d2856..0e702b62dcc9 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -532,20 +532,19 @@ def set_adapters(
         )
 
         list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
-        current_adapter_names = {adapter for _, adapter_list in list_adapters.items() for adapter in adapter_list}
-        for input_adapter_name in adapter_names:
-            if input_adapter_name not in current_adapter_names:
-                raise ValueError(
-                    f"Adapter name {input_adapter_name} not in the list of present adapters: {current_adapter_names}."
-                )
+        # eg ["adapter1", "adapter2"]
+        all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
+        missing_adapters = set(adapter_names) - all_adapters
+        if missing_adapters:
+            raise ValueError(
+                f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
+            )
 
-        all_adapters = {
-            adapter for adapters in list_adapters.values() for adapter in adapters
-        }  # eg ["adapter1", "adapter2"]
+        # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
         invert_list_adapters = {
             adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
             for adapter in all_adapters
-        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+        }
 
         # Decompose weights into weights for denoiser and text encoders.
         _component_adapter_weights = {}

From 33bee5d47772b0cb5ce8e95200fd7347045b2578 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 27 Sep 2024 07:20:43 +0530
Subject: [PATCH 4/4] Update src/diffusers/loaders/lora_base.py

Co-authored-by: YiYi Xu
---
 src/diffusers/loaders/lora_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 0e702b62dcc9..e124b6eeacf3 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -535,7 +535,7 @@ def set_adapters(
         # eg ["adapter1", "adapter2"]
         all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
         missing_adapters = set(adapter_names) - all_adapters
-        if missing_adapters:
+        if len(missing_adapters) > 0:
            raise ValueError(
                f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
            )
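
Note (illustrative, not part of the patch series): the sketch below shows the user-facing behavior these commits enforce: set_adapters() now fails fast with a ValueError when given an adapter name that was never loaded. The checkpoint id, LoRA repo id, and adapter names are placeholder assumptions for the example; only the validation behavior comes from the patches.

    # A minimal sketch, assuming an SDXL checkpoint and a hypothetical LoRA repo.
    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    # "some-user/some-lora" is a hypothetical repo id used only for illustration.
    pipe.load_lora_weights("some-user/some-lora", adapter_name="pixel")

    pipe.set_adapters("pixel")  # OK: "pixel" is among the loaded adapters.

    try:
        pipe.set_adapters("pixels")  # typo: no adapter with this name exists
    except ValueError as e:
        # With patches 3/4 applied, the message reads, e.g.:
        # Adapter name(s) {'pixels'} not in the list of present adapters: {'pixel'}.
        print(e)

One design consequence worth noting: the set difference introduced in patch 3 reports every missing name in a single error, whereas the per-name loop from patch 1 raised on the first mismatch; patch 4's switch to `len(missing_adapters) > 0` is a purely stylistic reviewer change.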