Skip to content

Commit 2b2abe2

Browse files
Add LoRA test for Z-Image
1 parent 03b5bcc commit 2b2abe2

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed

tests/test_lora_layers_z_image.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import os
16+
import sys
17+
import unittest
18+
19+
import torch
20+
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
21+
22+
from diffusers import (
23+
AutoencoderKL,
24+
FlowMatchEulerDiscreteScheduler,
25+
ZImagePipeline,
26+
ZImageTransformer2DModel,
27+
)
28+
29+
from ..testing_utils import floats_tensor, require_peft_backend
30+
31+
32+
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
33+
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
34+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
35+
torch.use_deterministic_algorithms(False)
36+
torch.backends.cudnn.deterministic = True
37+
torch.backends.cudnn.benchmark = False
38+
if hasattr(torch.backends, "cuda"):
39+
torch.backends.cuda.matmul.allow_tf32 = False
40+
41+
42+
sys.path.append(".")
43+
44+
from .utils import PeftLoraLoaderMixinTests # noqa: E402
45+
46+
47+
@require_peft_backend
48+
class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
49+
pipeline_class = ZImagePipeline
50+
scheduler_cls = FlowMatchEulerDiscreteScheduler
51+
scheduler_kwargs = {}
52+
53+
transformer_kwargs = {
54+
"all_patch_size": (2,),
55+
"all_f_patch_size": (1,),
56+
"in_channels": 16,
57+
"dim": 32,
58+
"n_layers": 2,
59+
"n_refiner_layers": 1,
60+
"n_heads": 2,
61+
"n_kv_heads": 2,
62+
"norm_eps": 1e-5,
63+
"qk_norm": True,
64+
"cap_feat_dim": 16,
65+
"rope_theta": 256.0,
66+
"t_scale": 1000.0,
67+
"axes_dims": [8, 4, 4],
68+
"axes_lens": [256, 32, 32],
69+
}
70+
transformer_cls = ZImageTransformer2DModel
71+
vae_kwargs = {
72+
"in_channels": 3,
73+
"out_channels": 3,
74+
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
75+
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
76+
"block_out_channels": [32, 64],
77+
"layers_per_block": 1,
78+
"latent_channels": 16,
79+
"norm_num_groups": 32,
80+
"sample_size": 32,
81+
"scaling_factor": 0.3611,
82+
"shift_factor": 0.1159,
83+
}
84+
vae_cls = AutoencoderKL
85+
tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
86+
text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline
87+
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
88+
89+
@property
90+
def output_shape(self):
91+
return (1, 8, 8, 3)
92+
93+
def get_dummy_inputs(self, with_generator=True):
94+
batch_size = 1
95+
sequence_length = 10
96+
num_channels = 4
97+
sizes = (32, 32)
98+
99+
generator = torch.manual_seed(0)
100+
noise = floats_tensor((batch_size, num_channels) + sizes)
101+
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
102+
103+
pipeline_inputs = {
104+
"prompt": "A painting of a squirrel eating a burger",
105+
"num_inference_steps": 4,
106+
"guidance_scale": 0.0,
107+
"height": 32,
108+
"width": 32,
109+
"max_sequence_length": 16,
110+
"output_type": "np",
111+
}
112+
if with_generator:
113+
pipeline_inputs.update({"generator": generator})
114+
115+
return noise, input_ids, pipeline_inputs
116+
117+
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
118+
# Override to create Qwen3Model inline since it doesn't have a pretrained tiny model
119+
torch.manual_seed(0)
120+
config = Qwen3Config(
121+
hidden_size=16,
122+
intermediate_size=16,
123+
num_hidden_layers=2,
124+
num_attention_heads=2,
125+
num_key_value_heads=2,
126+
vocab_size=151936,
127+
max_position_embeddings=512,
128+
)
129+
text_encoder = Qwen3Model(config)
130+
tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)
131+
132+
transformer = self.transformer_cls(**self.transformer_kwargs)
133+
vae = self.vae_cls(**self.vae_kwargs)
134+
135+
if scheduler_cls is None:
136+
scheduler_cls = self.scheduler_cls
137+
scheduler = scheduler_cls(**self.scheduler_kwargs)
138+
139+
return {
140+
"transformer": transformer,
141+
"vae": vae,
142+
"scheduler": scheduler,
143+
"text_encoder": text_encoder,
144+
"tokenizer": tokenizer,
145+
}
146+
147+
@unittest.skip("Not supported in ZImage.")
148+
def test_simple_inference_with_text_denoiser_block_scale(self):
149+
pass
150+
151+
@unittest.skip("Not supported in ZImage.")
152+
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
153+
pass
154+
155+
@unittest.skip("Not supported in ZImage.")
156+
def test_modify_padding_mode(self):
157+
pass
158+
159+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
160+
def test_simple_inference_with_partial_text_lora(self):
161+
pass
162+
163+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
164+
def test_simple_inference_with_text_lora(self):
165+
pass
166+
167+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
168+
def test_simple_inference_with_text_lora_and_scale(self):
169+
pass
170+
171+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
172+
def test_simple_inference_with_text_lora_fused(self):
173+
pass
174+
175+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
176+
def test_simple_inference_with_text_lora_save_load(self):
177+
pass

0 commit comments

Comments
 (0)