|
39 | 39 | class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): |
40 | 40 | pipeline_class = Lumina2Pipeline |
41 | 41 | scheduler_cls = FlowMatchEulerDiscreteScheduler |
42 | | - scheduler_classes = [FlowMatchEulerDiscreteScheduler] |
43 | 42 | scheduler_kwargs = {} |
44 | 43 |
|
45 | 44 | transformer_kwargs = { |
@@ -141,33 +140,30 @@ def test_simple_inference_with_text_lora_save_load(self): |
141 | 140 | strict=False, |
142 | 141 | ) |
def test_lora_fuse_nan(self):
    """Verify fuse_lora behavior when a LoRA weight is non-finite.

    After corrupting one adapter weight with ``inf``:
      * ``fuse_lora(..., safe_fusing=True)`` must raise ``ValueError``;
      * ``fuse_lora(..., safe_fusing=False)`` must succeed, after which the
        pipeline output is entirely NaN.
    """
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    # Attach the text-encoder adapter only when the pipeline supports LoRA there.
    if "text_encoder" in self.pipeline_class._lora_loadable_modules:
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

    # The denoiser is a UNet when unet_kwargs are configured, otherwise a transformer.
    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    denoiser.add_adapter(denoiser_lora_config, "adapter-1")
    self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

    # corrupt one LoRA weight with `inf` values
    with torch.no_grad():
        pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

    # with `safe_fusing=True` we should see an Error
    with self.assertRaises(ValueError):
        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

    # without we should not see an error, but every image will be black
    pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
    out = pipe(**inputs)[0]
    self.assertTrue(np.isnan(out).all())
0 commit comments