
Commit ea2f65d

Update pipeline_qwenimage_edit_plus.py
1 parent 693d8a3 commit ea2f65d


src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py

Lines changed: 48 additions & 12 deletions
@@ -252,9 +252,20 @@ def _get_qwen_prompt_embeds(
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(base_img_prompt + e) for e in prompt]

+        if image is None:
+            images_for_processor = None
+        else:
+            # If `image` is a single image (not list) the processor will broadcast it.
+            # If `image` is a list of conditioning images, we must repeat that list
+            # for each prompt so processor has one entry per text example.
+            if isinstance(image, list):
+                images_for_processor = [image] * len(txt)
+            else:
+                images_for_processor = image
+
         model_inputs = self.processor(
             text=txt,
-            images=image,
+            images=images_for_processor,
             padding=True,
             return_tensors="pt",
         ).to(device)
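
For reference, here is a standalone sketch of the new `images_for_processor` branching (illustrative only, not part of the commit; plain strings stand in for the PIL images the pipeline actually hands to the processor):

def build_images_for_processor(image, txt):
    # Mirrors the added branch: None passes through, a list of conditioning
    # images is repeated once per text example, anything else is passed as-is.
    if image is None:
        return None
    if isinstance(image, list):
        return [image] * len(txt)
    return image

txt = ["first prompt", "second prompt"]
image = ["cond_img_1", "cond_img_2"]

nested = build_images_for_processor(image, txt)
assert nested == [["cond_img_1", "cond_img_2"], ["cond_img_1", "cond_img_2"]]
assert len(nested) == len(txt)  # one image list per text example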
@@ -627,7 +638,12 @@ def __call__(
                 [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
                 returning a tuple, the first element is a list with the generated images.
         """
-        image_size = image[-1].size if isinstance(image, list) else image.size
+        # Use the first image's size as the deterministic base for output dims
+        ref_img = image[0] if isinstance(image, list) else image
+        if isinstance(ref_img, (tuple, list)):
+            ref_img = ref_img[0]
+        image_size = ref_img.size
+
         calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
         height = height or calculated_height
         width = width or calculated_width
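
A quick check of the new reference-size selection, as a sketch under stated assumptions (`FakeImage` is a hypothetical stand-in that exposes only the `.size` attribute of a PIL image):

class FakeImage:
    # Stand-in for a PIL image; only the .size attribute matters here.
    def __init__(self, size):
        self.size = size

def reference_size(image):
    # Same logic as the replaced block: take the first image (previously the last)
    # and unwrap one level of list/tuple nesting before reading .size.
    ref_img = image[0] if isinstance(image, list) else image
    if isinstance(ref_img, (tuple, list)):
        ref_img = ref_img[0]
    return ref_img.size

a, b = FakeImage((1024, 768)), FakeImage((512, 512))
assert reference_size(a) == (1024, 768)           # single image
assert reference_size([a, b]) == (1024, 768)      # list: first image now sets the size
assert reference_size([[a], [b]]) == (1024, 768)  # nested list: unwrap one level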
@@ -673,6 +689,7 @@ def __call__(
         vae_image_sizes = []
         vae_images = []
         for img in image:
+            img = img[0] if isinstance(img, (tuple, list)) else img
             image_width, image_height = img.size
             condition_width, condition_height = calculate_dimensions(
                 CONDITION_IMAGE_SIZE, image_width / image_height
@@ -681,7 +698,10 @@ def __call__(
             condition_image_sizes.append((condition_width, condition_height))
             vae_image_sizes.append((vae_width, vae_height))
             condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
-            vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+            preproc = self.image_processor.preprocess(img, vae_height, vae_width)
+            if isinstance(preproc, (tuple, list)):
+                preproc = preproc[0]
+            vae_images.append(preproc.unsqueeze(0))

         has_neg_prompt = negative_prompt is not None or (
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
@@ -719,6 +739,25 @@ def __call__(

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
+        if vae_images is not None:
+            for idx, v in enumerate(vae_images):
+                if isinstance(v, (tuple, list)):
+                    v = v[0]
+
+                if not torch.is_tensor(v):
+                    v = torch.as_tensor(v)
+
+                if v.ndim == 5 and v.shape[1] == 1 and v.shape[2] in (1, 3):
+                    v = v.permute(0, 2, 1, 3, 4).contiguous()
+
+                elif v.ndim == 4 and v.shape[1] in (1, 3):
+                    v = v.unsqueeze(2)
+
+                elif v.ndim == 3 and v.shape[0] in (1, 3):
+                    v = v.unsqueeze(0).unsqueeze(2)
+
+                vae_images[idx] = v
+
         latents, image_latents = self.prepare_latents(
             vae_images,
             batch_size * num_images_per_prompt,
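
The normalization loop added here converges every accepted input shape to a batched, channels-then-frame layout. Assuming the image processor's preprocess call returns a (1, C, H, W) tensor (as diffusers' VaeImageProcessor typically does), the new `preproc.unsqueeze(0)` from the previous hunk produces the 5-D case that the first branch permutes back into place. A small standalone check of that behaviour (assumes torch is installed; sizes are arbitrary):

import torch

def normalize_vae_image(v):
    # Same ndim/shape rules as the added loop; the accepted inputs all end up
    # with a singleton frame dimension after the channel axis.
    if isinstance(v, (tuple, list)):
        v = v[0]
    if not torch.is_tensor(v):
        v = torch.as_tensor(v)
    if v.ndim == 5 and v.shape[1] == 1 and v.shape[2] in (1, 3):
        v = v.permute(0, 2, 1, 3, 4).contiguous()
    elif v.ndim == 4 and v.shape[1] in (1, 3):
        v = v.unsqueeze(2)
    elif v.ndim == 3 and v.shape[0] in (1, 3):
        v = v.unsqueeze(0).unsqueeze(2)
    return v

h, w = 32, 32
for shape in [(1, 1, 3, h, w), (1, 3, h, w), (3, h, w)]:
    out = normalize_vae_image(torch.zeros(shape))
    assert out.shape == (1, 3, 1, h, w), out.shape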
@@ -730,15 +769,12 @@ def __call__(
             generator,
             latents,
         )
-        img_shapes = [
-            [
-                (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
-                *[
-                    (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
-                    for vae_width, vae_height in vae_image_sizes
-                ],
-            ]
-        ] * batch_size
+        base_shape = (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)
+        per_image_shapes = [
+            (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+            for vae_width, vae_height in vae_image_sizes
+        ]
+        img_shapes = [[base_shape, *per_image_shapes] for _ in range(batch_size)]

         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
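
Beyond naming the intermediate shapes, this rewrite also swaps list replication for a comprehension. The two differ only if the inner lists are later mutated: `[...] * batch_size` repeats references to a single inner list, while the comprehension builds an independent list per batch entry. The pipeline does not appear to mutate `img_shapes`, so this reads mainly as a clarity change; a minimal demonstration of the aliasing difference:

base_shape = (1, 16, 16)
per_image_shapes = [(1, 8, 8)]

# Old pattern: every batch entry aliases the same inner list object.
old = [[base_shape, *per_image_shapes]] * 2
old[0].append((1, 4, 4))
assert old[1][-1] == (1, 4, 4)  # the mutation shows up in both entries

# New pattern: each batch entry is an independent list.
new = [[base_shape, *per_image_shapes] for _ in range(2)]
new[0].append((1, 4, 4))
assert new[1][-1] == (1, 8, 8)  # the other entry is unaffected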
