|
| 1 | +diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py |
| 2 | +index 89bb498a3..8ecd4d459 100644 |
| 3 | +--- a/src/diffusers/loaders/lora_base.py |
| 4 | ++++ b/src/diffusers/loaders/lora_base.py |
| 5 | +@@ -532,6 +532,11 @@ class LoraBaseMixin: |
| 6 | + ) |
| 7 | + |
| 8 | + list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]} |
| 9 | ++ current_adapter_names = {adapter for _, adapter_list in list_adapters.items() for adapter in adapter_list} |
| 10 | ++ for input_adapter_name in adapter_names: |
| 11 | ++ if input_adapter_name not in current_adapter_names: |
| 12 | ++ raise ValueError(f"Adapter name {input_adapter_name} not in the list of present adapters: {current_adapter_names}.") |
| 13 | ++ |
| 14 | + all_adapters = { |
| 15 | + adapter for adapters in list_adapters.values() for adapter in adapters |
| 16 | + } # eg ["adapter1", "adapter2"] |
| 17 | +diff --git a/tests/lora/utils.py b/tests/lora/utils.py |
| 18 | +index 939b749c2..163260709 100644 |
| 19 | +--- a/tests/lora/utils.py |
| 20 | ++++ b/tests/lora/utils.py |
| 21 | +@@ -929,12 +929,16 @@ class PeftLoraLoaderMixinTests: |
| 22 | + |
| 23 | + pipe.set_adapters("adapter-1") |
| 24 | + output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 25 | ++ self.assertFalse(np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.") |
| 26 | ++ |
| 27 | + |
| 28 | + pipe.set_adapters("adapter-2") |
| 29 | + output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 30 | ++ self.assertFalse(np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.") |
| 31 | + |
| 32 | + pipe.set_adapters(["adapter-1", "adapter-2"]) |
| 33 | + output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 34 | ++ self.assertFalse(np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3), "Adapter outputs should be different.") |
| 35 | + |
| 36 | + # Fuse and unfuse should lead to the same results |
| 37 | + self.assertFalse( |
| 38 | +@@ -960,6 +964,40 @@ class PeftLoraLoaderMixinTests: |
| 39 | + "output with no lora and output with lora disabled should give same results", |
| 40 | + ) |
| 41 | + |
| 42 | ++ def test_wrong_adapter_name_raises_error(self): |
| 43 | ++ scheduler_cls = self.scheduler_classes[0] |
| 44 | ++ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) |
| 45 | ++ pipe = self.pipeline_class(**components) |
| 46 | ++ pipe = pipe.to(torch_device) |
| 47 | ++ pipe.set_progress_bar_config(disable=None) |
| 48 | ++ _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 49 | ++ |
| 50 | ++ if "text_encoder" in self.pipeline_class._lora_loadable_modules: |
| 51 | ++ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") |
| 52 | ++ self.assertTrue( |
| 53 | ++ check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" |
| 54 | ++ ) |
| 55 | ++ |
| 56 | ++ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 57 | ++ denoiser.add_adapter(denoiser_lora_config, "adapter-1") |
| 58 | ++ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") |
| 59 | ++ |
| 60 | ++ if self.has_two_text_encoders or self.has_three_text_encoders: |
| 61 | ++ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: |
| 62 | ++ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") |
| 63 | ++ self.assertTrue( |
| 64 | ++ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" |
| 65 | ++ ) |
| 66 | ++ |
| 67 | ++ with self.assertRaises(ValueError) as err_context: |
| 68 | ++ pipe.set_adapters("test") |
| 69 | ++ |
| 70 | ++ self.assertTrue("not in the list of present adapters" in str(err_context.exception)) |
| 71 | ++ |
| 72 | ++ # test this works. |
| 73 | ++ pipe.set_adapters("adapter-1") |
| 74 | ++ _ = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 75 | ++ |
| 76 | + def test_simple_inference_with_text_denoiser_block_scale(self): |
| 77 | + """ |
| 78 | + Tests a simple inference with lora attached to text encoder and unet, attaches |
0 commit comments