Skip to content

Commit ccbe500

Browse files
Add ZImage LoRA test overrides for architecture differences
- Override test_lora_fuse_nan to use ZImage's 'layers' attribute instead of 'transformer_blocks' - Skip block-level LoRA scaling test (not supported in ZImage) - Add required imports: numpy, torch_device, check_if_lora_correctly_set
1 parent ef2ccb0 commit ccbe500

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

tests/lora/test_lora_layers_z_image.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sys
1717
import unittest
1818

19+
import numpy as np
1920
import torch
2021
from peft import LoraConfig
2122
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
@@ -27,7 +28,7 @@
2728
ZImageTransformer2DModel,
2829
)
2930

30-
from ..testing_utils import floats_tensor, require_peft_backend
31+
from ..testing_utils import floats_tensor, require_peft_backend, torch_device
3132

3233

3334
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
@@ -42,7 +43,7 @@
4243

4344
sys.path.append(".")
4445

45-
from .utils import PeftLoraLoaderMixinTests # noqa: E402
46+
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
4647

4748

4849
@require_peft_backend
@@ -197,3 +198,33 @@ def test_simple_inference_with_text_lora_fused(self):
197198
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
198199
def test_simple_inference_with_text_lora_save_load(self):
199200
pass
201+
202+
@unittest.skip("Not supported in ZImage.")
203+
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
204+
pass
205+
206+
def test_lora_fuse_nan(self):
207+
"""Override to use ZImage's 'layers' attribute instead of 'transformer_blocks'."""
208+
components, _, denoiser_lora_config = self.get_dummy_components()
209+
pipe = self.pipeline_class(**components)
210+
pipe = pipe.to(torch_device)
211+
pipe.set_progress_bar_config(disable=None)
212+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
213+
214+
denoiser = pipe.transformer
215+
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
216+
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
217+
218+
# corrupt one LoRA weight with `inf` values - ZImage uses 'layers.X.attention'
219+
with torch.no_grad():
220+
pipe.transformer.layers[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
221+
222+
# with `safe_fusing=True` we should see an Error
223+
with self.assertRaises(ValueError):
224+
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
225+
226+
# without we should not see an error, but every image will be black
227+
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
228+
out = pipe(**inputs)[0]
229+
230+
self.assertTrue(np.isnan(out).all())

0 commit comments

Comments
 (0)