|
15 | 15 | import sys |
16 | 16 | import unittest |
17 | 17 |
|
18 | | -import numpy as np |
19 | | -import pytest |
20 | 18 | import torch |
21 | 19 | from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast |
22 | 20 |
|
|
28 | 26 | ) |
29 | 27 | from diffusers.utils.testing_utils import ( |
30 | 28 | floats_tensor, |
31 | | - is_torch_version, |
32 | 29 | require_peft_backend, |
33 | 30 | skip_mps, |
34 | | - torch_device, |
35 | 31 | ) |
36 | 32 |
|
37 | 33 |
|
38 | 34 | sys.path.append(".") |
39 | 35 |
|
40 | | -from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 |
| 36 | +from utils import PeftLoraLoaderMixinTests # noqa: E402 |
41 | 37 |
|
42 | 38 |
|
43 | 39 | @require_peft_backend |
@@ -144,46 +140,6 @@ def get_dummy_inputs(self, with_generator=True): |
144 | 140 |
|
145 | 141 | return noise, input_ids, pipeline_inputs |
146 | 142 |
|
147 | | - @pytest.mark.xfail( |
148 | | - condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), |
149 | | - reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", |
150 | | - strict=True, |
151 | | - ) |
152 | | - def test_lora_fuse_nan(self): |
153 | | - for scheduler_cls in self.scheduler_classes: |
154 | | - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) |
155 | | - pipe = self.pipeline_class(**components) |
156 | | - pipe = pipe.to(torch_device) |
157 | | - pipe.set_progress_bar_config(disable=None) |
158 | | - _, _, inputs = self.get_dummy_inputs(with_generator=False) |
159 | | - |
160 | | - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") |
161 | | - |
162 | | - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") |
163 | | - |
164 | | - # corrupt one LoRA weight with `inf` values |
165 | | - with torch.no_grad(): |
166 | | - pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") |
167 | | - |
168 | | - # with `safe_fusing=True` we should see an Error |
169 | | - with self.assertRaises(ValueError): |
170 | | - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) |
171 | | - |
172 | | - # without we should not see an error, but every image will be black |
173 | | - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) |
174 | | - |
175 | | - out = pipe( |
176 | | - prompt=inputs["prompt"], |
177 | | - height=inputs["height"], |
178 | | - width=inputs["width"], |
179 | | - num_frames=inputs["num_frames"], |
180 | | - num_inference_steps=inputs["num_inference_steps"], |
181 | | - max_sequence_length=inputs["max_sequence_length"], |
182 | | - output_type="np", |
183 | | - )[0] |
184 | | - |
185 | | - self.assertTrue(np.isnan(out).all()) |
186 | | - |
187 | 143 | def test_simple_inference_with_text_lora_denoiser_fused_multi(self): |
188 | 144 | super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) |
189 | 145 |
|
|
0 commit comments