|  | 
| 15 | 15 | import sys | 
| 16 | 16 | import unittest | 
| 17 | 17 | 
 | 
| 18 |  | -import numpy as np | 
| 19 |  | -import pytest | 
| 20 | 18 | import torch | 
| 21 | 19 | from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast | 
| 22 | 20 | 
 | 
|  | 
| 28 | 26 | ) | 
| 29 | 27 | from diffusers.utils.testing_utils import ( | 
| 30 | 28 |     floats_tensor, | 
| 31 |  | -    is_torch_version, | 
| 32 | 29 |     require_peft_backend, | 
| 33 | 30 |     skip_mps, | 
| 34 |  | -    torch_device, | 
| 35 | 31 | ) | 
| 36 | 32 | 
 | 
| 37 | 33 | 
 | 
| 38 | 34 | sys.path.append(".") | 
| 39 | 35 | 
 | 
| 40 |  | -from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402 | 
|  | 36 | +from utils import PeftLoraLoaderMixinTests  # noqa: E402 | 
| 41 | 37 | 
 | 
| 42 | 38 | 
 | 
| 43 | 39 | @require_peft_backend | 
| @@ -144,46 +140,6 @@ def get_dummy_inputs(self, with_generator=True): | 
| 144 | 140 | 
 | 
| 145 | 141 |         return noise, input_ids, pipeline_inputs | 
| 146 | 142 | 
 | 
| 147 |  | -    @pytest.mark.xfail( | 
| 148 |  | -        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), | 
| 149 |  | -        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", | 
| 150 |  | -        strict=True, | 
| 151 |  | -    ) | 
| 152 |  | -    def test_lora_fuse_nan(self): | 
| 153 |  | -        for scheduler_cls in self.scheduler_classes: | 
| 154 |  | -            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) | 
| 155 |  | -            pipe = self.pipeline_class(**components) | 
| 156 |  | -            pipe = pipe.to(torch_device) | 
| 157 |  | -            pipe.set_progress_bar_config(disable=None) | 
| 158 |  | -            _, _, inputs = self.get_dummy_inputs(with_generator=False) | 
| 159 |  | - | 
| 160 |  | -            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") | 
| 161 |  | - | 
| 162 |  | -            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") | 
| 163 |  | - | 
| 164 |  | -            # corrupt one LoRA weight with `inf` values | 
| 165 |  | -            with torch.no_grad(): | 
| 166 |  | -                pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") | 
| 167 |  | - | 
| 168 |  | -            # with `safe_fusing=True` we should see an Error | 
| 169 |  | -            with self.assertRaises(ValueError): | 
| 170 |  | -                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) | 
| 171 |  | - | 
| 172 |  | -            # without we should not see an error, but every image will be black | 
| 173 |  | -            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) | 
| 174 |  | - | 
| 175 |  | -            out = pipe( | 
| 176 |  | -                prompt=inputs["prompt"], | 
| 177 |  | -                height=inputs["height"], | 
| 178 |  | -                width=inputs["width"], | 
| 179 |  | -                num_frames=inputs["num_frames"], | 
| 180 |  | -                num_inference_steps=inputs["num_inference_steps"], | 
| 181 |  | -                max_sequence_length=inputs["max_sequence_length"], | 
| 182 |  | -                output_type="np", | 
| 183 |  | -            )[0] | 
| 184 |  | - | 
| 185 |  | -            self.assertTrue(np.isnan(out).all()) | 
| 186 |  | - | 
| 187 | 143 |     def test_simple_inference_with_text_lora_denoiser_fused_multi(self): | 
| 188 | 144 |         super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) | 
| 189 | 145 | 
 | 
|  | 
0 commit comments