
Commit 1dba96c

Merge branch 'main' into chroma_docs
2 parents: c039c94 + 5796735

File tree

2 files changed: +286 -1 lines changed

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 33 additions & 1 deletion
@@ -26,6 +26,7 @@ Qwen-Image comes in the following variants:
 |:----------:|:--------:|
 | Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
 | Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
+| Qwen-Image-Edit Plus | [`Qwen/Qwen-Image-Edit-2509`](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |

 <Tip>

@@ -96,6 +97,29 @@ The `guidance_scale` parameter in the pipeline is there to support future guidance-distilled models

 </Tip>

+## Multi-image reference with QwenImageEditPlusPipeline
+
+With [`QwenImageEditPlusPipeline`], you can provide multiple reference images as input.
+
+```py
+import torch
+from PIL import Image
+from diffusers import QwenImageEditPlusPipeline
+from diffusers.utils import load_image
+
+pipe = QwenImageEditPlusPipeline.from_pretrained(
+    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
+).to("cuda")
+
+image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
+image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
+image = pipe(
+    image=[image_1, image_2],
+    prompt='put the penguin and the cat at a game show called "Qwen Edit Plus Games"',
+    num_inference_steps=50,
+).images[0]
+```
+
 ## QwenImagePipeline

 [[autodoc]] QwenImagePipeline
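A small follow-up to the example above, assuming it ran as written: like other Diffusers pipelines, `QwenImageEditPlusPipeline` accepts a `generator` for reproducible sampling, and the returned PIL image can be saved directly. The seed and output file name below are illustrative.

```py
# Re-run the edit with a fixed seed for reproducibility, then save the result.
# The seed value and output file name are illustrative choices.
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    image=[image_1, image_2],
    prompt='put the penguin and the cat at a game show called "Qwen Edit Plus Games"',
    num_inference_steps=50,
    generator=generator,
).images[0]
image.save("qwen_edit_plus_games.png")
```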
@@ -126,7 +150,15 @@ The `guidance_scale` parameter in the pipeline is there to support future guidance-distilled models
 - all
 - __call__

-## QwenImaggeControlNetPipeline
+## QwenImageControlNetPipeline
+
+[[autodoc]] QwenImageControlNetPipeline
+- all
+- __call__
+
+## QwenImageEditPlusPipeline
+
+[[autodoc]] QwenImageEditPlusPipeline
 - all
 - __call__

Lines changed: 253 additions & 0 deletions
@@ -0,0 +1,253 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import pytest
import torch
from PIL import Image
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

from diffusers import (
    AutoencoderKLQwenImage,
    FlowMatchEulerDiscreteScheduler,
    QwenImageEditPlusPipeline,
    QwenImageTransformer2DModel,
)

from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = QwenImageEditPlusPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = frozenset(["prompt", "image"])
    image_params = frozenset(["image"])
    image_latents_params = frozenset(["latents"])
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"

        torch.manual_seed(0)
        transformer = QwenImageTransformer2DModel(
            patch_size=2,
            in_channels=16,
            out_channels=4,
            num_layers=2,
            attention_head_dim=16,
            num_attention_heads=3,
            joint_attention_dim=16,
            guidance_embeds=False,
            axes_dims_rope=(8, 4, 4),
        )

        torch.manual_seed(0)
        z_dim = 4
        vae = AutoencoderKLQwenImage(
            base_dim=z_dim * 6,
            z_dim=z_dim,
            dim_mult=[1, 2, 4],
            num_res_blocks=1,
            temperal_downsample=[False, True],
            latents_mean=[0.0] * z_dim,
            latents_std=[1.0] * z_dim,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()

        torch.manual_seed(0)
        config = Qwen2_5_VLConfig(
            text_config={
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1000000.0,
            },
            vision_config={
                "depth": 2,
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_heads": 2,
                "out_hidden_size": 16,
            },
            hidden_size=16,
            vocab_size=152064,
            vision_end_token_id=151653,
            vision_start_token_id=151652,
            vision_token_id=151654,
        )
        text_encoder = Qwen2_5_VLForConditionalGeneration(config)
        tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "processor": Qwen2VLProcessor.from_pretrained(tiny_ckpt_id),
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = Image.new("RGB", (32, 32))
        inputs = {
            "prompt": "dance monkey",
            "image": [image, image],
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "true_cfg_scale": 1.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))

        # fmt: off
        expected_slice = torch.tensor([[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174, 0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
        # fmt: on

        generated_slice = generated_image.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

@pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
240+
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
241+
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
242+
243+
@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
244+
def test_num_images_per_prompt():
245+
super().test_num_images_per_prompt()
246+
247+
@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
248+
def test_inference_batch_consistent():
249+
super().test_inference_batch_consistent()
250+
251+
@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
252+
def test_inference_batch_single_identical():
253+
super().test_inference_batch_single_identical()
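For local iteration on the new test class, a minimal invocation sketch; the test module's path is not shown in this commit view, so the snippet filters by class name instead.

```py
# Hypothetical local run of the new fast tests, filtering by the class name
# added in this commit; pytest collects from the current working directory.
import pytest

pytest.main(["-k", "QwenImageEditPlusPipelineFastTests", "-x", "-q"])
```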
