Skip to content

Commit 8379627

Browse files
committed
added test qwen image controlnet
1 parent 9a7ae77 commit 8379627

File tree

1 file changed

+281
-0
lines changed

1 file changed

+281
-0
lines changed
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
# Copyright 2025 The HuggingFace Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import torch
19+
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
20+
21+
from diffusers import (
22+
AutoencoderKLQwenImage,
23+
FlowMatchEulerDiscreteScheduler,
24+
QwenImageTransformer2DModel,
25+
QwenImageControlNetPipeline
26+
)
27+
from diffusers.models.controlnets.controlnet_qwenimage import QwenImageControlNetModel
28+
from diffusers.utils import load_image
29+
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
30+
from diffusers.utils.torch_utils import randn_tensor
31+
32+
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
33+
from ..test_pipelines_common import PipelineTesterMixin, to_np, FluxIPAdapterTesterMixin
34+
35+
36+
enable_full_determinism()
37+
38+
39+
40+
class QwenControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
41+
pipeline_class = QwenImageControlNetPipeline
42+
params = (TEXT_TO_IMAGE_PARAMS | frozenset(["control_image", "controlnet_conditioning_scale"])) - {"cross_attention_kwargs"}
43+
batch_params = frozenset(["prompt", "negative_prompt", "control_image"])
44+
image_params = frozenset(["control_image"])
45+
image_latents_params = frozenset(["latents"])
46+
47+
required_optional_params = frozenset(
48+
[
49+
"num_inference_steps",
50+
"generator",
51+
"latents",
52+
"control_image",
53+
"controlnet_conditioning_scale",
54+
"return_dict",
55+
"callback_on_step_end",
56+
"callback_on_step_end_tensor_inputs",
57+
]
58+
)
59+
60+
supports_dduf = False
61+
test_xformers_attention = True
62+
test_layerwise_casting = True
63+
test_group_offloading = True
64+
65+
def get_dummy_components(self):
66+
torch.manual_seed(0)
67+
transformer = QwenImageTransformer2DModel(
68+
patch_size=2,
69+
in_channels=16,
70+
out_channels=4,
71+
num_layers=2,
72+
attention_head_dim=16,
73+
num_attention_heads=3,
74+
joint_attention_dim=16,
75+
guidance_embeds=False,
76+
axes_dims_rope=(8, 4, 4),
77+
)
78+
79+
80+
torch.manual_seed(0)
81+
controlnet = QwenImageControlNetModel(
82+
patch_size=2,
83+
in_channels=16,
84+
out_channels=4,
85+
num_layers=2,
86+
attention_head_dim=16,
87+
num_attention_heads=3,
88+
joint_attention_dim=16,
89+
axes_dims_rope=(8, 4, 4)
90+
)
91+
92+
93+
torch.manual_seed(0)
94+
z_dim = 4
95+
vae = AutoencoderKLQwenImage(
96+
base_dim=z_dim * 6,
97+
z_dim=z_dim,
98+
dim_mult=[1, 2, 4],
99+
num_res_blocks=1,
100+
temperal_downsample=[False, True],
101+
latents_mean=[0.0] * z_dim,
102+
latents_std=[1.0] * z_dim,
103+
)
104+
105+
torch.manual_seed(0)
106+
scheduler = FlowMatchEulerDiscreteScheduler()
107+
108+
torch.manual_seed(0)
109+
config = Qwen2_5_VLConfig(
110+
text_config={
111+
"hidden_size": 16,
112+
"intermediate_size": 16,
113+
"num_hidden_layers": 2,
114+
"num_attention_heads": 2,
115+
"num_key_value_heads": 2,
116+
"rope_scaling": {
117+
"mrope_section": [1, 1, 2],
118+
"rope_type": "default",
119+
"type": "default",
120+
},
121+
"rope_theta": 1_000_000.0,
122+
},
123+
vision_config={
124+
"depth": 2,
125+
"hidden_size": 16,
126+
"intermediate_size": 16,
127+
"num_heads": 2,
128+
"out_hidden_size": 16,
129+
},
130+
hidden_size=16,
131+
vocab_size=152064,
132+
vision_end_token_id=151653,
133+
vision_start_token_id=151652,
134+
vision_token_id=151654,
135+
)
136+
137+
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
138+
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
139+
140+
components = {
141+
"transformer": transformer,
142+
"vae": vae,
143+
"scheduler": scheduler,
144+
"text_encoder": text_encoder,
145+
"tokenizer": tokenizer,
146+
"controlnet": controlnet,
147+
}
148+
return components
149+
150+
def get_dummy_inputs(self, device, seed=0):
151+
if str(device).startswith("mps"):
152+
generator = torch.manual_seed(seed)
153+
else:
154+
generator = torch.Generator(device=device).manual_seed(seed)
155+
156+
control_image = randn_tensor(
157+
(1, 3, 32, 32),
158+
generator=generator,
159+
device=torch.device(device),
160+
dtype=torch.float32,
161+
)
162+
163+
inputs = {
164+
"prompt": "dance monkey",
165+
"negative_prompt": "bad quality",
166+
"generator": generator,
167+
"num_inference_steps": 2,
168+
"guidance_scale": 3.0,
169+
"true_cfg_scale": 1.0,
170+
"height": 32,
171+
"width": 32,
172+
"max_sequence_length": 16,
173+
"control_image": control_image,
174+
"controlnet_conditioning_scale": 0.5,
175+
"output_type": "pt",
176+
}
177+
178+
return inputs
179+
180+
def test_qwen_controlnet(self):
181+
device = "cpu"
182+
components = self.get_dummy_components()
183+
pipe = self.pipeline_class(**components)
184+
pipe.to(device)
185+
pipe.set_progress_bar_config(disable=None)
186+
187+
inputs = self.get_dummy_inputs(device)
188+
image = pipe(**inputs).images
189+
generated_image = image[0]
190+
self.assertEqual(generated_image.shape, (3, 32, 32))
191+
192+
# Expected slice from the generated image
193+
expected_slice = torch.tensor(
194+
[0.4726, 0.5549, 0.6324, 0.6548, 0.4968, 0.4639, 0.4749, 0.4898, 0.4725, 0.4645, 0.4435, 0.3339, 0.3400, 0.4630, 0.3879, 0.4406]
195+
)
196+
197+
generated_slice = generated_image.flatten()
198+
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
199+
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
200+
201+
def test_attention_slicing_forward_pass(
202+
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
203+
):
204+
if not self.test_attention_slicing:
205+
return
206+
207+
components = self.get_dummy_components()
208+
pipe = self.pipeline_class(**components)
209+
for component in pipe.components.values():
210+
if hasattr(component, "set_default_attn_processor"):
211+
component.set_default_attn_processor()
212+
pipe.to(torch_device)
213+
pipe.set_progress_bar_config(disable=None)
214+
215+
generator_device = "cpu"
216+
inputs = self.get_dummy_inputs(generator_device)
217+
output_without_slicing = pipe(**inputs)[0]
218+
219+
pipe.enable_attention_slicing(slice_size=1)
220+
inputs = self.get_dummy_inputs(generator_device)
221+
output_with_slicing1 = pipe(**inputs)[0]
222+
223+
pipe.enable_attention_slicing(slice_size=2)
224+
inputs = self.get_dummy_inputs(generator_device)
225+
output_with_slicing2 = pipe(**inputs)[0]
226+
227+
if test_max_difference:
228+
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
229+
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
230+
self.assertLess(
231+
max(max_diff1, max_diff2),
232+
expected_max_diff,
233+
"Attention slicing should not affect the inference results",
234+
)
235+
236+
def test_inference_batch_single_identical(self):
237+
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
238+
239+
def test_vae_tiling(self, expected_diff_max: float = 0.2):
240+
generator_device = "cpu"
241+
components = self.get_dummy_components()
242+
243+
pipe = self.pipeline_class(**components)
244+
pipe.to("cpu")
245+
pipe.set_progress_bar_config(disable=None)
246+
247+
# Without tiling
248+
inputs = self.get_dummy_inputs(generator_device)
249+
inputs["height"] = inputs["width"] = 128
250+
inputs["control_image"] = randn_tensor(
251+
(1, 3, 128, 128),
252+
generator=inputs["generator"],
253+
device=torch.device(generator_device),
254+
dtype=torch.float32,
255+
)
256+
output_without_tiling = pipe(**inputs)[0]
257+
258+
# With tiling
259+
pipe.vae.enable_tiling(
260+
tile_sample_min_height=96,
261+
tile_sample_min_width=96,
262+
tile_sample_stride_height=64,
263+
tile_sample_stride_width=64,
264+
)
265+
inputs = self.get_dummy_inputs(generator_device)
266+
inputs["height"] = inputs["width"] = 128
267+
inputs["control_image"] = randn_tensor(
268+
(1, 3, 128, 128),
269+
generator=inputs["generator"],
270+
device=torch.device(generator_device),
271+
dtype=torch.float32,
272+
)
273+
output_with_tiling = pipe(**inputs)[0]
274+
275+
self.assertLess(
276+
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
277+
expected_diff_max,
278+
"VAE tiling should not affect the inference results",
279+
)
280+
281+

0 commit comments

Comments
 (0)