|  | 
| 15 | 15 | import sys | 
| 16 | 16 | import unittest | 
| 17 | 17 | 
 | 
|  | 18 | +import numpy as np | 
|  | 19 | +import pytest | 
| 18 | 20 | import torch | 
| 19 | 21 | from transformers import AutoTokenizer, GemmaForCausalLM | 
| 20 | 22 | 
 | 
|  | 
| 24 | 26 |     Lumina2Text2ImgPipeline, | 
| 25 | 27 |     Lumina2Transformer2DModel, | 
| 26 | 28 | ) | 
| 27 |  | -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend | 
|  | 29 | +from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device | 
| 28 | 30 | 
 | 
| 29 | 31 | 
 | 
| 30 | 32 | sys.path.append(".") | 
| 31 | 33 | 
 | 
| 32 |  | -from utils import PeftLoraLoaderMixinTests  # noqa: E402 | 
|  | 34 | +from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402 | 
| 33 | 35 | 
 | 
| 34 | 36 | 
 | 
| 35 | 37 | @require_peft_backend | 
| @@ -130,3 +132,41 @@ def test_simple_inference_with_text_lora_fused(self): | 
| 130 | 132 |     @unittest.skip("Text encoder LoRA is not supported in Lumina2.") | 
| 131 | 133 |     def test_simple_inference_with_text_lora_save_load(self): | 
| 132 | 134 |         pass | 
|  | 135 | + | 
|  | 136 | +    @skip_mps | 
|  | 137 | +    @pytest.mark.xfail( | 
|  | 138 | +        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), | 
|  | 139 | +        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", | 
|  | 140 | +        strict=False, | 
|  | 141 | +    ) | 
|  | 142 | +    def test_lora_fuse_nan(self): | 
|  | 143 | +        for scheduler_cls in self.scheduler_classes: | 
|  | 144 | +            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) | 
|  | 145 | +            pipe = self.pipeline_class(**components) | 
|  | 146 | +            pipe = pipe.to(torch_device) | 
|  | 147 | +            pipe.set_progress_bar_config(disable=None) | 
|  | 148 | +            _, _, inputs = self.get_dummy_inputs(with_generator=False) | 
|  | 149 | + | 
|  | 150 | +            if "text_encoder" in self.pipeline_class._lora_loadable_modules: | 
|  | 151 | +                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") | 
|  | 152 | +                self.assertTrue( | 
|  | 153 | +                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" | 
|  | 154 | +                ) | 
|  | 155 | + | 
|  | 156 | +            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet | 
|  | 157 | +            denoiser.add_adapter(denoiser_lora_config, "adapter-1") | 
|  | 158 | +            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") | 
|  | 159 | + | 
|  | 160 | +            # corrupt one LoRA weight with `inf` values | 
|  | 161 | +            with torch.no_grad(): | 
|  | 162 | +                pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") | 
|  | 163 | + | 
|  | 164 | +            # with `safe_fusing=True` we should see an Error | 
|  | 165 | +            with self.assertRaises(ValueError): | 
|  | 166 | +                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) | 
|  | 167 | + | 
|  | 168 | +            # without we should not see an error, but every image will be black | 
|  | 169 | +            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) | 
|  | 170 | +            out = pipe(**inputs)[0] | 
|  | 171 | + | 
|  | 172 | +            self.assertTrue(np.isnan(out).all()) | 
0 commit comments