src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py (42 changes: 39 additions & 3 deletions)
@@ -252,9 +252,17 @@ def _get_qwen_prompt_embeds(
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(base_img_prompt + e) for e in prompt]
 
+        if image is None:
+            images_for_processor = None
+        else:
+            if isinstance(image, list):
+                images_for_processor = [image] * len(txt)
+            else:
+                images_for_processor = image
+
         model_inputs = self.processor(
             text=txt,
-            images=image,
+            images=images_for_processor,
             padding=True,
             return_tensors="pt",
         ).to(device)
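
Note: when a flat list of reference images is passed, this hunk replicates it once per entry in `text=` before the `self.processor(...)` call, presumably so that each prompt in the batch is paired with the full set of reference images. A minimal sketch of the nesting rule in isolation; the helper name `nest_images_for_processor` is hypothetical, not part of the pipeline:

def nest_images_for_processor(image, txt):
    # One image list per text entry; None passes through untouched.
    if image is None:
        return None
    if isinstance(image, list):
        # Replicate the full reference-image list for every prompt string.
        return [image] * len(txt)
    # A single image is handed to the processor unchanged.
    return image

For example, with two prompts and reference images `img_a, img_b` (hypothetical), the processor would receive `[[img_a, img_b], [img_a, img_b]]`.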
@@ -627,7 +635,12 @@ def __call__(
             [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is a list with the generated images.
         """
-        image_size = image[-1].size if isinstance(image, list) else image.size
+
+        ref_img = image[0] if isinstance(image, list) else image
+        if isinstance(ref_img, (tuple, list)):
+            ref_img = ref_img[0]
+        image_size = ref_img.size
+
         calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
         height = height or calculated_height
         width = width or calculated_width
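
Note: `image` may now arrive as a single PIL image, a flat list, or a list containing lists/tuples of images, and the default output resolution is derived from the first reference image in all three cases (previously it followed the last list entry, `image[-1]`). A sketch of the unwrapping; `first_reference_size` is a hypothetical name:

def first_reference_size(image):
    # Unwrap at most one level of list/tuple nesting and read the PIL size.
    ref_img = image[0] if isinstance(image, list) else image
    if isinstance(ref_img, (tuple, list)):
        ref_img = ref_img[0]
    return ref_img.size  # PIL convention: (width, height)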
@@ -673,6 +686,7 @@ def __call__(
             vae_image_sizes = []
             vae_images = []
             for img in image:
+                img = img[0] if isinstance(img, (tuple, list)) else img
                 image_width, image_height = img.size
                 condition_width, condition_height = calculate_dimensions(
                     CONDITION_IMAGE_SIZE, image_width / image_height
@@ -681,7 +695,10 @@
                 condition_image_sizes.append((condition_width, condition_height))
                 vae_image_sizes.append((vae_width, vae_height))
                 condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
-                vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+                preproc = self.image_processor.preprocess(img, vae_height, vae_width)
+                if isinstance(preproc, (tuple, list)):
+                    preproc = preproc[0]
+                vae_images.append(preproc.unsqueeze(0))
 
         has_neg_prompt = negative_prompt is not None or (
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
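
Note: `VaeImageProcessor.preprocess` returns a `(B, C, H, W)` tensor, so `preproc.unsqueeze(0)` yields `(1, B, C, H, W)` rather than the `(B, C, T, H, W)` layout the old `.unsqueeze(2)` produced; the normalization pass added in step 4 below is what permutes it back. A shapes-only check of that round trip with a dummy tensor:

import torch

preproc = torch.randn(1, 3, 512, 512)  # (B, C, H, W), as returned by preprocess
v = preproc.unsqueeze(0)               # (1, 1, 3, 512, 512)
if v.ndim == 5 and v.shape[1] == 1 and v.shape[2] in (1, 3):
    v = v.permute(0, 2, 1, 3, 4).contiguous()
assert v.shape == (1, 3, 1, 512, 512)  # (B, C, T=1, H, W)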
@@ -719,6 +736,25 @@

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
+        if vae_images is not None:
+            for idx, v in enumerate(vae_images):
+                if isinstance(v, (tuple, list)):
+                    v = v[0]
+
+                if not torch.is_tensor(v):
+                    v = torch.as_tensor(v)
+
+                if v.ndim == 5 and v.shape[1] == 1 and v.shape[2] in (1, 3):
+                    v = v.permute(0, 2, 1, 3, 4).contiguous()
+
+                elif v.ndim == 4 and v.shape[1] in (1, 3):
+                    v = v.unsqueeze(2)
+
+                elif v.ndim == 3 and v.shape[0] in (1, 3):
+                    v = v.unsqueeze(0).unsqueeze(2)
+
+                vae_images[idx] = v
+
         latents, image_latents = self.prepare_latents(
             vae_images,
             batch_size * num_images_per_prompt,
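
Note: this guard coerces every entry of `vae_images` to the 5D `(B, C, T, H, W)` layout that `prepare_latents` consumes, covering the layouts the preceding code can produce. A standalone sketch of the same branches, exercised on dummy tensors; `normalize_vae_image` is a hypothetical name:

import torch

def normalize_vae_image(v):
    # Mirror of the branches above: coerce any accepted layout to (B, C, T, H, W).
    if isinstance(v, (tuple, list)):
        v = v[0]
    if not torch.is_tensor(v):
        v = torch.as_tensor(v)
    if v.ndim == 5 and v.shape[1] == 1 and v.shape[2] in (1, 3):
        v = v.permute(0, 2, 1, 3, 4).contiguous()  # (1, 1, C, H, W) -> (1, C, 1, H, W)
    elif v.ndim == 4 and v.shape[1] in (1, 3):
        v = v.unsqueeze(2)                          # (B, C, H, W) -> (B, C, 1, H, W)
    elif v.ndim == 3 and v.shape[0] in (1, 3):
        v = v.unsqueeze(0).unsqueeze(2)             # (C, H, W) -> (1, C, 1, H, W)
    return v

for shape in [(1, 1, 3, 64, 64), (1, 3, 64, 64), (3, 64, 64)]:
    assert normalize_vae_image(torch.randn(*shape)).shape == (1, 3, 1, 64, 64)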