diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf166c..65d664978fa5 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -252,9 +252,17 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(base_img_prompt + e) for e in prompt] + if image is None: + images_for_processor = None + else: + if isinstance(image, list): + images_for_processor = [image] * len(txt) + else: + images_for_processor = image + model_inputs = self.processor( text=txt, - images=image, + images=images_for_processor, padding=True, return_tensors="pt", ).to(device) @@ -627,7 +635,12 @@ def __call__( [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ - image_size = image[-1].size if isinstance(image, list) else image.size + + ref_img = image[0] if isinstance(image, list) else image + if isinstance(ref_img, (tuple, list)): + ref_img = ref_img[0] + image_size = ref_img.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) height = height or calculated_height width = width or calculated_width @@ -673,6 +686,7 @@ def __call__( vae_image_sizes = [] vae_images = [] for img in image: + img = img[0] if isinstance(img, (tuple, list)) else img image_width, image_height = img.size condition_width, condition_height = calculate_dimensions( CONDITION_IMAGE_SIZE, image_width / image_height @@ -681,7 +695,10 @@ def __call__( condition_image_sizes.append((condition_width, condition_height)) vae_image_sizes.append((vae_width, vae_height)) condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) - vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + preproc = self.image_processor.preprocess(img, vae_height, vae_width) + if isinstance(preproc, (tuple, list)): + preproc = preproc[0] + vae_images.append(preproc.unsqueeze(0)) has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None @@ -719,6 +736,25 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 + if vae_images is not None: + for idx, v in enumerate(vae_images): + if isinstance(v, (tuple, list)): + v = v[0] + + if not torch.is_tensor(v): + v = torch.as_tensor(v) + + if v.ndim == 5 and v.shape[1] == 1 and v.shape[2] in (1, 3): + v = v.permute(0, 2, 1, 3, 4).contiguous() + + elif v.ndim == 4 and v.shape[1] in (1, 3): + v = v.unsqueeze(2) + + elif v.ndim == 3 and v.shape[0] in (1, 3): + v = v.unsqueeze(0).unsqueeze(2) + + vae_images[idx] = v + latents, image_latents = self.prepare_latents( vae_images, batch_size * num_images_per_prompt,