tests/pipelines/qwenimage/test_qwenimage_controlnet.py (new file: 343 additions, 0 deletions)
@@ -0,0 +1,343 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer

from diffusers import (
    AutoencoderKLQwenImage,
    FlowMatchEulerDiscreteScheduler,
    QwenImageControlNetModel,
    QwenImageControlNetPipeline,
    QwenImageMultiControlNetModel,
    QwenImageTransformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.torch_utils import randn_tensor

from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


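# Fast tests for QwenImageControlNetPipeline, built entirely from tiny randomly initialized components.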
class QwenControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = QwenImageControlNetPipeline
    params = (TEXT_TO_IMAGE_PARAMS | frozenset(["control_image", "controlnet_conditioning_scale"])) - {
        "cross_attention_kwargs"
    }
    batch_params = frozenset(["prompt", "negative_prompt", "control_image"])
    image_params = frozenset(["control_image"])
    image_latents_params = frozenset(["latents"])

    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "control_image",
            "controlnet_conditioning_scale",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )

    supports_dduf = False
    test_xformers_attention = True
    test_layerwise_casting = True
    test_group_offloading = True

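    # Builds a miniature version of every pipeline component so the full pipeline runs quickly on CPU.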
    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = QwenImageTransformer2DModel(
            patch_size=2,
            in_channels=16,
            out_channels=4,
            num_layers=2,
            attention_head_dim=16,
            num_attention_heads=3,
            joint_attention_dim=16,
            guidance_embeds=False,
            axes_dims_rope=(8, 4, 4),
        )

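        # The ControlNet mirrors the transformer's configuration (minus guidance embeddings)
        # so its control features stay shape-compatible with the denoiser.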
        torch.manual_seed(0)
        controlnet = QwenImageControlNetModel(
            patch_size=2,
            in_channels=16,
            out_channels=4,
            num_layers=2,
            attention_head_dim=16,
            num_attention_heads=3,
            joint_attention_dim=16,
            axes_dims_rope=(8, 4, 4),
        )

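        # Tiny VAE; zero means and unit stds make the latent normalization a no-op in these tests.
        # (`temperal_downsample` matches the argument's spelling in the diffusers API.)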
        torch.manual_seed(0)
        z_dim = 4
        vae = AutoencoderKLQwenImage(
            base_dim=z_dim * 6,
            z_dim=z_dim,
            dim_mult=[1, 2, 4],
            num_res_blocks=1,
            temperal_downsample=[False, True],
            latents_mean=[0.0] * z_dim,
            latents_std=[1.0] * z_dim,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()

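        # Minimal Qwen2.5-VL text encoder config; the vision token ids are chosen to
        # match Qwen2.5-VL's real special tokens.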
torch.manual_seed(0)
config = Qwen2_5_VLConfig(
text_config={
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"rope_scaling": {
"mrope_section": [1, 1, 2],
"rope_type": "default",
"type": "default",
},
"rope_theta": 1_000_000.0,
},
vision_config={
"depth": 2,
"hidden_size": 16,
"intermediate_size": 16,
"num_heads": 2,
"out_hidden_size": 16,
},
hidden_size=16,
vocab_size=152064,
vision_end_token_id=151653,
vision_start_token_id=151652,
vision_token_id=151654,
)

text_encoder = Qwen2_5_VLForConditionalGeneration(config)
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "controlnet": controlnet,
        }
        return components

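    # Deterministic dummy inputs; MPS lacks device-local generator support, hence the global-seed fallback.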
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        control_image = randn_tensor(
            (1, 3, 32, 32),
            generator=generator,
            device=torch.device(device),
            dtype=torch.float32,
        )

        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "true_cfg_scale": 1.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "control_image": control_image,
            "controlnet_conditioning_scale": 0.5,
            "output_type": "pt",
        }

        return inputs

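    # End-to-end single-ControlNet run; compares the first and last eight output pixels to reference values.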
    def test_qwen_controlnet(self):
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))

        # Expected slice from the generated image
        expected_slice = torch.tensor(
            [
                0.4726, 0.5549, 0.6324, 0.6548, 0.4968, 0.4639, 0.4749, 0.4898,
                0.4725, 0.4645, 0.4435, 0.3339, 0.3400, 0.4630, 0.3879, 0.4406,
            ]
        )

        generated_slice = generated_image.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

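    # Same pipeline, but the ControlNet is wrapped in QwenImageMultiControlNetModel and given two copies of the condition.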
    def test_qwen_controlnet_multicondition(self):
        device = "cpu"
        components = self.get_dummy_components()

        components["controlnet"] = QwenImageMultiControlNetModel([components["controlnet"]])

        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        control_image = inputs["control_image"]
        inputs["control_image"] = [control_image, control_image]
        inputs["controlnet_conditioning_scale"] = [0.5, 0.5]

        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))
        # Expected slice from the generated image
        expected_slice = torch.tensor(
            [
                0.6239, 0.6642, 0.5768, 0.6039, 0.5270, 0.5070, 0.5006, 0.5271,
                0.4506, 0.3085, 0.3435, 0.5152, 0.5096, 0.5422, 0.4286, 0.5752,
            ]
        )

        generated_slice = generated_image.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

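    # Runs inference with and without attention slicing and checks the outputs agree within expected_max_diff.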
    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)

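    # A 128x128 output with 96-pixel tiles and 64-pixel strides forces overlapping VAE tiles;
    # tiled and untiled decoding should still roughly match.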
    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        inputs["control_image"] = randn_tensor(
            (1, 3, 128, 128),
            generator=inputs["generator"],
            device=torch.device(generator_device),
            dtype=torch.float32,
        )
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        inputs["control_image"] = randn_tensor(
            (1, 3, 128, 128),
            generator=inputs["generator"],
            device=torch.device(generator_device),
            dtype=torch.float32,
        )
        output_with_tiling = pipe(**inputs)[0]

        # Use the absolute difference so deviations in either direction are caught.
        self.assertLess(
            np.abs(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )