Commit af24d9d

Merge branch 'main' into update-skyreels-v2
2 parents: df1f6b7 + 22b229b

File tree: 1 file changed, +339 -0 lines changed

Lines changed: 339 additions & 0 deletions
@@ -0,0 +1,339 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer

from diffusers import (
    AutoencoderKLQwenImage,
    FlowMatchEulerDiscreteScheduler,
    QwenImageControlNetModel,
    QwenImageControlNetPipeline,
    QwenImageMultiControlNetModel,
    QwenImageTransformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.torch_utils import randn_tensor

from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


class QwenControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = QwenImageControlNetPipeline
    params = (TEXT_TO_IMAGE_PARAMS | frozenset(["control_image", "controlnet_conditioning_scale"])) - {
        "cross_attention_kwargs"
    }
    batch_params = frozenset(["prompt", "negative_prompt", "control_image"])
    image_params = frozenset(["control_image"])
    image_latents_params = frozenset(["latents"])

    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "control_image",
            "controlnet_conditioning_scale",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )

    supports_dduf = False
    test_xformers_attention = True
    test_layerwise_casting = True
    test_group_offloading = True

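    # Build tiny versions of every pipeline component (2-layer transformer/ControlNet, small VAE,
    # 16-dim text encoder) so the fast tests can run end to end on CPU.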
    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = QwenImageTransformer2DModel(
            patch_size=2,
            in_channels=16,
            out_channels=4,
            num_layers=2,
            attention_head_dim=16,
            num_attention_heads=3,
            joint_attention_dim=16,
            guidance_embeds=False,
            axes_dims_rope=(8, 4, 4),
        )

        torch.manual_seed(0)
        controlnet = QwenImageControlNetModel(
            patch_size=2,
            in_channels=16,
            out_channels=4,
            num_layers=2,
            attention_head_dim=16,
            num_attention_heads=3,
            joint_attention_dim=16,
            axes_dims_rope=(8, 4, 4),
        )

        torch.manual_seed(0)
        z_dim = 4
        vae = AutoencoderKLQwenImage(
            base_dim=z_dim * 6,
            z_dim=z_dim,
            dim_mult=[1, 2, 4],
            num_res_blocks=1,
            temperal_downsample=[False, True],
            latents_mean=[0.0] * z_dim,
            latents_std=[1.0] * z_dim,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()

        torch.manual_seed(0)
        config = Qwen2_5_VLConfig(
            text_config={
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1_000_000.0,
            },
            vision_config={
                "depth": 2,
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_heads": 2,
                "out_hidden_size": 16,
            },
            hidden_size=16,
            vocab_size=152064,
            vision_end_token_id=151653,
            vision_start_token_id=151652,
            vision_token_id=151654,
        )

        text_encoder = Qwen2_5_VLForConditionalGeneration(config)
        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "controlnet": controlnet,
        }
        return components

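    # Minimal text-to-image inputs plus a random 32x32 control image and a conditioning scale.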
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        control_image = randn_tensor(
            (1, 3, 32, 32),
            generator=generator,
            device=torch.device(device),
            dtype=torch.float32,
        )

        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "true_cfg_scale": 1.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "control_image": control_image,
            "controlnet_conditioning_scale": 0.5,
            "output_type": "pt",
        }

        return inputs

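    # End-to-end run with a single ControlNet, checked against a reference pixel slice.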
    def test_qwen_controlnet(self):
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))

        # Expected slice from the generated image
        expected_slice = torch.tensor(
            [
                0.4726,
                0.5549,
                0.6324,
                0.6548,
                0.4968,
                0.4639,
                0.4749,
                0.4898,
                0.4725,
                0.4645,
                0.4435,
                0.3339,
                0.3400,
                0.4630,
                0.3879,
                0.4406,
            ]
        )

        generated_slice = generated_image.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

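    # Same pipeline with QwenImageMultiControlNetModel: control images and scales become lists.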
    def test_qwen_controlnet_multicondition(self):
        device = "cpu"
        components = self.get_dummy_components()

        components["controlnet"] = QwenImageMultiControlNetModel([components["controlnet"]])

        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        control_image = inputs["control_image"]
        inputs["control_image"] = [control_image, control_image]
        inputs["controlnet_conditioning_scale"] = [0.5, 0.5]

        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))
        # Expected slice from the generated image
        expected_slice = torch.tensor(
            [
                0.6239,
                0.6642,
                0.5768,
                0.6039,
                0.5270,
                0.5070,
                0.5006,
                0.5271,
                0.4506,
                0.3085,
                0.3435,
                0.5152,
                0.5096,
                0.5422,
                0.4286,
                0.5752,
            ]
        )

        generated_slice = generated_image.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

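    # Overrides the mixin test: outputs with and without attention slicing must stay within tolerance.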
    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)

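    # VAE tiling at 128x128 should reproduce the untiled output within the given tolerance.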
    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        inputs["control_image"] = randn_tensor(
            (1, 3, 128, 128),
            generator=inputs["generator"],
            device=torch.device(generator_device),
            dtype=torch.float32,
        )
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        inputs["control_image"] = randn_tensor(
            (1, 3, 128, 128),
            generator=inputs["generator"],
            device=torch.device(generator_device),
            dtype=torch.float32,
        )
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

0 commit comments
