|
15 | 15 | import sys |
16 | 16 | import unittest |
17 | 17 |
|
| 18 | +import numpy as np |
18 | 19 | import torch |
19 | 20 | from transformers import AutoTokenizer, GemmaForCausalLM |
20 | 21 |
|
|
24 | 25 | Lumina2Text2ImgPipeline, |
25 | 26 | Lumina2Transformer2DModel, |
26 | 27 | ) |
27 | | -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend |
| 28 | +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, torch_device |
28 | 29 |
|
29 | 30 |
|
30 | 31 | sys.path.append(".") |
31 | 32 |
|
32 | | -from utils import PeftLoraLoaderMixinTests # noqa: E402 |
| 33 | +from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 |
33 | 34 |
|
34 | 35 |
|
35 | 36 | @require_peft_backend |
@@ -130,3 +131,45 @@ def test_simple_inference_with_text_lora_fused(self): |
    # Intentionally skipped: Lumina2's loader does not expose text-encoder LoRA,
    # so the mixin's save/load round-trip test cannot apply to this pipeline.
    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
| 134 | + |
| 135 | + def test_lora_fuse_nan(self): |
| 136 | + for scheduler_cls in self.scheduler_classes: |
| 137 | + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) |
| 138 | + pipe = self.pipeline_class(**components) |
| 139 | + pipe = pipe.to(torch_device) |
| 140 | + pipe.set_progress_bar_config(disable=None) |
| 141 | + _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 142 | + |
| 143 | + if "text_encoder" in self.pipeline_class._lora_loadable_modules: |
| 144 | + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") |
| 145 | + self.assertTrue( |
| 146 | + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" |
| 147 | + ) |
| 148 | + |
| 149 | + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 150 | + denoiser.add_adapter(denoiser_lora_config, "adapter-1") |
| 151 | + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") |
| 152 | + |
| 153 | + # corrupt one LoRA weight with `inf` values |
| 154 | + with torch.no_grad(): |
| 155 | + if self.unet_kwargs: |
| 156 | + pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[ |
| 157 | + "adapter-1" |
| 158 | + ].weight += float("inf") |
| 159 | + else: |
| 160 | + named_modules = [name for name, _ in pipe.transformer.named_modules()] |
| 161 | + has_attn1 = any("attn1" in name for name in named_modules) |
| 162 | + if has_attn1: |
| 163 | + pipe.transformer.layers[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") |
| 164 | + else: |
| 165 | + pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") |
| 166 | + |
| 167 | + # with `safe_fusing=True` we should see an Error |
| 168 | + with self.assertRaises(ValueError): |
| 169 | + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) |
| 170 | + |
| 171 | + # without we should not see an error, but every image will be black |
| 172 | + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) |
| 173 | + out = pipe(**inputs)[0] |
| 174 | + |
| 175 | + self.assertTrue(np.isnan(out).all()) |
0 commit comments