|
16 | 16 | import sys |
17 | 17 | import unittest |
18 | 18 |
|
| 19 | +import numpy as np |
19 | 20 | import torch |
20 | 21 | from peft import LoraConfig |
21 | 22 | from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model |
|
27 | 28 | ZImageTransformer2DModel, |
28 | 29 | ) |
29 | 30 |
|
30 | | -from ..testing_utils import floats_tensor, require_peft_backend |
| 31 | +from ..testing_utils import floats_tensor, require_peft_backend, torch_device |
31 | 32 |
|
32 | 33 |
|
33 | 34 | # Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations |
|
42 | 43 |
|
43 | 44 | sys.path.append(".") |
44 | 45 |
|
45 | | -from .utils import PeftLoraLoaderMixinTests # noqa: E402 |
| 46 | +from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 |
46 | 47 |
|
47 | 48 |
|
48 | 49 | @require_peft_backend |
@@ -197,3 +198,33 @@ def test_simple_inference_with_text_lora_fused(self): |
    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_text_lora_save_load(self):
        # Inherited from PeftLoraLoaderMixinTests; disabled because the ZImage
        # pipeline does not support text-encoder LoRA adapters.
        pass
| 201 | + |
    @unittest.skip("Not supported in ZImage.")
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        # Inherited from PeftLoraLoaderMixinTests; disabled because per-block
        # multi-adapter LoRA is not supported by this pipeline.
        pass
| 205 | + |
| 206 | + def test_lora_fuse_nan(self): |
| 207 | + """Override to use ZImage's 'layers' attribute instead of 'transformer_blocks'.""" |
| 208 | + components, _, denoiser_lora_config = self.get_dummy_components() |
| 209 | + pipe = self.pipeline_class(**components) |
| 210 | + pipe = pipe.to(torch_device) |
| 211 | + pipe.set_progress_bar_config(disable=None) |
| 212 | + _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 213 | + |
| 214 | + denoiser = pipe.transformer |
| 215 | + denoiser.add_adapter(denoiser_lora_config, "adapter-1") |
| 216 | + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") |
| 217 | + |
| 218 | + # corrupt one LoRA weight with `inf` values - ZImage uses 'layers.X.attention' |
| 219 | + with torch.no_grad(): |
| 220 | + pipe.transformer.layers[0].attention.to_q.lora_A["adapter-1"].weight += float("inf") |
| 221 | + |
| 222 | + # with `safe_fusing=True` we should see an Error |
| 223 | + with self.assertRaises(ValueError): |
| 224 | + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) |
| 225 | + |
| 226 | + # without we should not see an error, but every image will be black |
| 227 | + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) |
| 228 | + out = pipe(**inputs)[0] |
| 229 | + |
| 230 | + self.assertTrue(np.isnan(out).all()) |
0 commit comments