 import numpy as np
 import pytest
 import torch
+from parameterized import parameterized

 from diffusers import (
     AutoencoderKL,
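
Aside (not part of the diff): `parameterized.expand` is imported here for the new `test_lora_scale_kwargs_match_fusion` test added further down; it turns one test method into one generated test per listed argument. A minimal sketch using a hypothetical test case, not the mixin from this file:

import unittest

from parameterized import parameterized


class ScaleTests(unittest.TestCase):  # hypothetical example class
    @parameterized.expand([1.0, 0.8])
    def test_scale_is_positive(self, lora_scale):
        # Runs once per value in the list (two generated tests).
        self.assertGreater(lora_scale, 0.0)


if __name__ == "__main__":
    unittest.main()
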
@@ -80,6 +81,18 @@ def initialize_dummy_state_dict(state_dict):
 POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


+def determine_attention_kwargs_name(pipeline_class):
+    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
+
+    # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
+    for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
+        if possible_attention_kwargs in call_signature_keys:
+            attention_kwargs_name = possible_attention_kwargs
+            break
+    assert attention_kwargs_name is not None
+    return attention_kwargs_name
+
+
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     pipeline_class = None
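
For context (not part of the diff): the new helper just inspects the pipeline's `__call__` signature and picks whichever of the known attention-kwargs parameter names it finds, which is what lets the duplicated lookup loops below be deleted. A self-contained sketch of the same idea against a hypothetical pipeline class (the early return stands in for the break/assert pattern of the real helper):

import inspect

POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


def determine_attention_kwargs_name(pipeline_class):
    # Pick the first known attention-kwargs parameter exposed by __call__.
    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
    for name in POSSIBLE_ATTENTION_KWARGS_NAMES:
        if name in call_signature_keys:
            return name
    raise AssertionError("pipeline exposes no known attention kwargs parameter")


class DummyTransformerPipeline:  # hypothetical stand-in, not a real diffusers pipeline
    def __call__(self, prompt, joint_attention_kwargs=None):
        return prompt


assert determine_attention_kwargs_name(DummyTransformerPipeline) == "joint_attention_kwargs"
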
@@ -440,14 +453,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         Tests a simple inference with lora attached on the text encoder + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-
-        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -803,12 +809,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         Tests a simple inference with lora attached on the text encoder + Unet + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1765,7 +1766,8 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             pipe.set_adapters(["adapter-1"])
             outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

-            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            # pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
             assert pipe.num_fused_loras == 1

             # Fusing should still keep the LoRA layers so outpout should remain the same
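
Aside (not part of the diff): the assertion that fusing does not change the output follows from what fusion does, namely folding the low-rank update into the base weight; unfusing subtracts the same delta again. A toy numpy sketch of that round trip, with made-up shapes:

import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))          # base weight
A = rng.normal(size=(4, 16)) * 0.1    # LoRA down-projection
B = rng.normal(size=(8, 4)) * 0.1     # LoRA up-projection
x = rng.normal(size=(2, 16))          # dummy activations
delta = B @ A

W_fused = W + delta                   # fuse: bake the adapter into the weight
W_restored = W_fused - delta          # unfuse: subtract the same delta

# Fused evaluation matches the unfused adapter path, and unfusing restores W.
assert np.allclose(x @ W_fused.T, x @ W.T + (x @ A.T) @ B.T, atol=1e-6)
assert np.allclose(W, W_restored, atol=1e-6)
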
@@ -1803,6 +1805,62 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
                 "Fused lora should not change the output",
             )

+    @parameterized.expand([1.0, 0.8])
+    def test_lora_scale_kwargs_match_fusion(
+        self, lora_scale, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
+    ):
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(output_no_lora.shape == self.output_shape)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+
+            pipe.set_adapters(["adapter-1"])
+            attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
+            pipe.fuse_lora(
+                components=self.pipeline_class._lora_loadable_modules,
+                adapter_names=["adapter-1"],
+                lora_scale=lora_scale,
+            )
+            assert pipe.num_fused_loras == 1
+
+            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            self.assertTrue(
+                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
+                "Fused lora should not change the output",
+            )
+            self.assertFalse(
+                np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
+                "LoRA should change the output",
+            )
+
     @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
         for scheduler_cls in self.scheduler_classes:
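
For illustration only (not part of the diff): the new test asserts that passing `{"scale": lora_scale}` through the pipeline's attention kwargs at call time is equivalent to fusing with `fuse_lora(..., lora_scale=lora_scale)`. The arithmetic behind that expectation, in a toy numpy sketch mirroring the two parameterized scales:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 16))
W = rng.normal(size=(8, 16))
A = rng.normal(size=(4, 16)) * 0.1
B = rng.normal(size=(8, 4)) * 0.1

for lora_scale in (1.0, 0.8):  # mirrors @parameterized.expand([1.0, 0.8])
    # Runtime scaling: the low-rank path is multiplied by the scale on each call.
    y_scaled = x @ W.T + lora_scale * ((x @ A.T) @ B.T)
    # Fusion with lora_scale: the scaled delta is merged into the weight once.
    y_fused = x @ (W + lora_scale * (B @ A)).T
    assert np.allclose(y_scaled, y_fused, atol=1e-6)
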
@@ -2007,12 +2065,7 @@ def test_logs_info_when_no_lora_keys_found(self):

     def test_set_adapters_match_attention_kwargs(self):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)