
Commit 35744eb ("up")
Parent: 615a420

4 files changed: +33 additions, -24 deletions

src/diffusers/models/transformers/transformer_qwenimage.py
Lines changed: 4 additions & 2 deletions

@@ -219,6 +219,7 @@ def forward(self, video_fhw, txt_seq_lens, device):
                 video_freq = self.rope_cache[rope_key]
             else:
                 video_freq = self._compute_video_freqs(frame, height, width, idx)
+            video_freq = video_freq.to(device)
             vid_freqs.append(video_freq)
 
             if self.scale_rope:

@@ -249,8 +250,9 @@ def _compute_video_freqs(self, frame, height, width, idx=0):
         freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
 
         freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-        return freqs.clone().contiguous()
-
+        freqs = freqs.clone().contiguous()
+
+        return freqs
 
 class QwenDoubleStreamAttnProcessor2_0:
     """

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
Lines changed: 6 additions & 4 deletions

@@ -183,7 +183,7 @@ def calculate_dimensions(target_area, ratio):
 
 class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
     r"""
-    The QwenImage pipeline for text-to-image generation.
+    The Qwen-Image-Edit pipeline for image editing.
 
     Args:
         transformer ([`QwenImageTransformer2DModel`]):

@@ -222,8 +222,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.latent_channels = 16
         self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
         # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)

@@ -258,7 +258,7 @@ def _get_qwen_prompt_embeds(
         template = self.prompt_template_encode
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(e) for e in prompt]
-
+
         model_inputs = self.processor(
             text=txt,
             images=image,

@@ -640,7 +640,9 @@ def __call__(
                 [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
                 returning a tuple, the first element is a list with the generated images.
         """
-        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image.width / image.height)
+        image_size = image[0].size if isinstance(image, list) else image.size
+        width, height = image_size
+        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
         height = height or calculated_height
         width = width or calculated_width
 
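
Two behavioral changes sit in this file: the constructor now derives latent_channels from the VAE config (z_dim) instead of hard-coding 16, and __call__ accepts either a single PIL image or a list of images when computing the aspect ratio passed to calculate_dimensions. A rough standalone illustration of those two guards follows; the SimpleNamespace config and the dummy image are stand-ins for the real AutoencoderKLQwenImage and user input, not pipeline code.

from types import SimpleNamespace

from PIL import Image

# Stand-in for the VAE; only config.z_dim matters for this sketch.
vae = SimpleNamespace(config=SimpleNamespace(z_dim=16))

# Mirrors: self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
latent_channels = vae.config.z_dim if vae is not None else 16

# Mirrors: image_size = image[0].size if isinstance(image, list) else image.size
image = [Image.new("RGB", (768, 512))]  # a list of images is now handled as well
image_size = image[0].size if isinstance(image, list) else image.size
width, height = image_size

print(latent_channels, width / height)  # 16 1.5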

tests/models/transformers/test_models_transformer_qwenimage.py
Lines changed: 1 addition & 1 deletion

@@ -103,4 +103,4 @@ def prepare_dummy_input(self, height, width):
 
     @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True)
     def test_torch_compile_recompilation_and_graph_break(self):
-        pass
+        super().test_torch_compile_recompilation_and_graph_break()
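
Context for the one-line change: the test is marked xfail(strict=True), so a body of pass would pass unexpectedly (XPASS) and strict mode reports that as a failure; delegating to the parent implementation runs the real torch.compile test, which is currently expected to fail. A tiny pytest example of that strict-xfail behavior, with made-up test names:

import pytest


@pytest.mark.xfail(condition=True, reason="demo", strict=True)
def test_body_is_pass():
    pass  # unexpectedly passes, so strict=True reports this as FAILED (XPASS)


@pytest.mark.xfail(condition=True, reason="demo", strict=True)
def test_body_really_fails():
    assert 1 == 2  # fails as expected, reported as xfailed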

tests/pipelines/qwenimage/test_qwenimage_edit.py
Lines changed: 22 additions & 17 deletions

@@ -13,16 +13,16 @@
 # limitations under the License.
 
 import unittest
-
+import pytest
 import numpy as np
 import torch
 from PIL import Image
-from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
 
 from diffusers import (
     AutoencoderKLQwenImage,
     FlowMatchEulerDiscreteScheduler,
-    QwenImagePipeline,
+    QwenImageEditPipeline,
     QwenImageTransformer2DModel,
 )
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device

@@ -34,12 +34,12 @@
 enable_full_determinism()
 
 
-class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
-    pipeline_class = QwenImagePipeline
+class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = QwenImageEditPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
-    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
-    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    batch_params = frozenset(["prompt", "image"])
+    image_params = frozenset(["image"])
+    image_latents_params = frozenset(["latents"])
     required_optional_params = frozenset(
         [
             "num_inference_steps",

@@ -56,6 +56,8 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_group_offloading = True
 
     def get_dummy_components(self):
+        tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+
         torch.manual_seed(0)
         transformer = QwenImageTransformer2DModel(
             patch_size=2,

@@ -77,10 +79,8 @@ def get_dummy_components(self):
             dim_mult=[1, 2, 4],
             num_res_blocks=1,
             temperal_downsample=[False, True],
-            # fmt: off
-            latents_mean=[0.0] * 4,
-            latents_std=[1.0] * 4,
-            # fmt: on
+            latents_mean=[0.0] * z_dim,
+            latents_std=[1.0] * z_dim,
         )
 
         torch.manual_seed(0)

@@ -115,14 +115,15 @@ def get_dummy_components(self):
             vision_token_id=151654,
         )
         text_encoder = Qwen2_5_VLForConditionalGeneration(config)
-        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+        tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)
 
         components = {
             "transformer": transformer,
             "vae": vae,
             "scheduler": scheduler,
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
+            "processor": Qwen2VLProcessor.from_pretrained(tiny_ckpt_id),
         }
         return components
 

@@ -134,7 +135,7 @@ def get_dummy_inputs(self, device, seed=0):
 
         inputs = {
             "prompt": "dance monkey",
-            "image": Image.new("RGB", (16, 16)),
+            "image": Image.new("RGB", (32, 32)),
             "negative_prompt": "bad quality",
             "generator": generator,
             "num_inference_steps": 2,

@@ -160,13 +161,13 @@ def test_inference(self):
         generated_image = image[0]
         self.assertEqual(generated_image.shape, (3, 32, 32))
 
-        # fmt: off
-        expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208])
+        expected_slice = torch.tensor(
+            [[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174,
+              0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
         # fmt: on
 
         generated_slice = generated_image.flatten()
         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-        print(f"{generated_slice=}")
         self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     def test_inference_batch_single_identical(self):

@@ -236,3 +237,7 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             expected_diff_max,
             "VAE tiling should not affect the inference results",
         )
+
+    @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
+    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
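
The reworked test_inference compares only the first and last eight values of the flattened output against expected_slice via torch.allclose with atol=1e-3. A minimal sketch of that slice-comparison idiom on a dummy tensor, independent of the pipeline:

import torch

generated_image = torch.rand(3, 32, 32)  # stand-in for the pipeline output image

generated_slice = generated_image.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])  # 16 values

expected_slice = generated_slice.clone()  # the real test hard-codes this tensor
assert generated_slice.shape == (16,)
assert torch.allclose(generated_slice, expected_slice, atol=1e-3)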
