
Commit 615a420

tests
1 parent df737cc commit 615a420

File tree

1 file changed (+238 lines, -0 lines)

Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from PIL import Image
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer

from diffusers import (
    AutoencoderKLQwenImage,
    FlowMatchEulerDiscreteScheduler,
    QwenImagePipeline,
    QwenImageTransformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = QwenImagePipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = QwenImageTransformer2DModel(
            patch_size=2,
            in_channels=16,
            out_channels=4,
            num_layers=2,
            attention_head_dim=16,
            num_attention_heads=3,
            joint_attention_dim=16,
            guidance_embeds=False,
            axes_dims_rope=(8, 4, 4),
        )

        torch.manual_seed(0)
        z_dim = 4
        vae = AutoencoderKLQwenImage(
            base_dim=z_dim * 6,
            z_dim=z_dim,
            dim_mult=[1, 2, 4],
            num_res_blocks=1,
            temperal_downsample=[False, True],
            # fmt: off
            latents_mean=[0.0] * 4,
            latents_std=[1.0] * 4,
            # fmt: on
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()

        torch.manual_seed(0)
        config = Qwen2_5_VLConfig(
            text_config={
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1000000.0,
            },
            vision_config={
                "depth": 2,
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_heads": 2,
                "out_hidden_size": 16,
            },
            hidden_size=16,
            vocab_size=152064,
            vision_end_token_id=151653,
            vision_start_token_id=151652,
            vision_token_id=151654,
        )
        text_encoder = Qwen2_5_VLForConditionalGeneration(config)
        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "dance monkey",
            "image": Image.new("RGB", (16, 16)),
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "true_cfg_scale": 1.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))

        # fmt: off
        expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208])
        # fmt: on

        generated_slice = generated_image.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        print(f"{generated_slice=}")
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )
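
As a quick orientation (not part of the commit): the dummy components built in get_dummy_components can also be assembled into the pipeline by hand, which is roughly what PipelineTesterMixin and test_inference do internally. The sketch below is a minimal illustration, assuming the names from the test module above are in scope (for example, by running it from the same file) and that the hf-internal-testing tokenizer can be downloaded.

# Hedged sketch, not part of this commit: reuse the test helper to build the
# tiny QwenImagePipeline and inspect the configs of its sub-models.
tests = QwenImagePipelineFastTests()          # the TestCase is instantiated only to reuse the helper
components = tests.get_dummy_components()     # tiny transformer, VAE, scheduler, text encoder, tokenizer
pipe = QwenImagePipeline(**components).to("cpu")

print(pipe.transformer.config.num_layers)     # 2, matching the dummy transformer above
print(pipe.vae.config.z_dim)                  # 4, matching z_dim above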
