diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 6c072c482020..cee6c6af1273 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -16,6 +16,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import PIL
 import torch
 from transformers import (
     CLIPTextModel,
@@ -389,10 +390,31 @@ def encode_prompt(
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+    def check_image(self, image, prompt, prompt_embeds):
+        image_is_pil = isinstance(image, PIL.Image.Image)
+        if image_is_pil:
+            image_batch_size = 1
+        else:
+            image_batch_size = len(image)
+
+        if prompt is not None and isinstance(prompt, str):
+            prompt_batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            prompt_batch_size = len(prompt)
+        elif prompt_embeds is not None:
+            prompt_batch_size = prompt_embeds.shape[0]
+
+        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+            raise ValueError(
+                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+            )
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
+        image,
         height,
         width,
         prompt_embeds=None,
@@ -429,6 +451,30 @@ def check_inputs(
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
+        if (
+            isinstance(self.controlnet, FluxControlNetModel)
+        ):
+            self.check_image(image, prompt, prompt_embeds)
+        elif (
+            isinstance(self.controlnet, FluxMultiControlNetModel)
+        ):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            # When `image` is a nested list:
+            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+            elif any(isinstance(i, list) for i in image):
+                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        else:
+            assert False
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
@@ -523,18 +569,20 @@ def prepare_image(
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        if isinstance(image, torch.Tensor):
-            pass
-        else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
-
+
+        image = self.image_processor.preprocess(image, height=height, width=width)
         image_batch_size = image.shape[0]
 
         if image_batch_size == 1:
-            repeat_by = batch_size
-        else:
+            repeat_by = batch_size * num_images_per_prompt
+        elif image_batch_size == batch_size:
             # image batch size is the same as prompt batch size
             repeat_by = num_images_per_prompt
+        else:
+            raise ValueError(
+                "`image_batch_size` must be either 1 or equal to the prompt " + \
+                f"batch size, which is {batch_size}."
+            )
 
         image = image.repeat_interleave(repeat_by, dim=0)
 
@@ -678,6 +726,7 @@ def __call__(
         self.check_inputs(
             prompt,
             prompt_2,
+            control_image,
             height,
             width,
             prompt_embeds=prompt_embeds,
@@ -726,7 +775,7 @@ def __call__(
                 image=control_image,
                 width=width,
                 height=height,
-                batch_size=batch_size * num_images_per_prompt,
+                batch_size=batch_size,
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=self.vae.dtype,
@@ -762,7 +811,7 @@ def __call__(
                     image=control_image_,
                     width=width,
                     height=height,
-                    batch_size=batch_size * num_images_per_prompt,
+                    batch_size=batch_size,
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=self.vae.dtype,
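
The rule enforced by the new check_image is the same one used by the other ControlNet pipelines: a control image input must have batch size 1 or exactly match the prompt batch size. The sketch below restates that rule outside the pipeline so it can be run in isolation; the helper names (_image_batch_size, validate_image_vs_prompt) are illustrative only and are not part of the diffusers API, and the sketch assumes only PIL and torch.

# Standalone sketch of the batch-size rule that check_image enforces.
# Illustrative helpers only; not the pipeline API.
from typing import List, Optional, Union

import PIL.Image
import torch


def _image_batch_size(image: Union[PIL.Image.Image, List, torch.Tensor]) -> int:
    # A single PIL image counts as batch size 1; lists and tensors use their leading dimension.
    if isinstance(image, PIL.Image.Image):
        return 1
    return len(image)


def validate_image_vs_prompt(image, prompt=None, prompt_embeds: Optional[torch.Tensor] = None) -> None:
    image_batch_size = _image_batch_size(image)

    if isinstance(prompt, str):
        prompt_batch_size = 1
    elif isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]
    else:
        raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")

    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, it must match the prompt batch size. "
            f"Got image batch size {image_batch_size} and prompt batch size {prompt_batch_size}."
        )


# Example: two control images with a single prompt now fails fast instead of
# surfacing later as a shape mismatch.
images = [PIL.Image.new("RGB", (64, 64)), PIL.Image.new("RGB", (64, 64))]
try:
    validate_image_vs_prompt(images, prompt="a photo of a cat")
except ValueError as err:
    print(err)

For FluxMultiControlNetModel, the patch first requires `image` to be a flat list with one entry per ControlNet and then applies this same check to each entry.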
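
The prepare_image change also moves the num_images_per_prompt multiplication inside the method, which is why both __call__ sites now pass batch_size rather than batch_size * num_images_per_prompt. Below is a minimal sketch of that repeat arithmetic under the same assumptions as the diff; repeat_control_image is an illustrative name, not a pipeline method, and the sketch depends only on torch.

# Sketch of the revised tiling logic: a single control image is tiled to
# batch_size * num_images_per_prompt, a per-prompt batch is tiled by
# num_images_per_prompt, and any other batch size is rejected.
import torch


def repeat_control_image(image: torch.Tensor, batch_size: int, num_images_per_prompt: int) -> torch.Tensor:
    image_batch_size = image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size * num_images_per_prompt
    elif image_batch_size == batch_size:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt
    else:
        raise ValueError(
            f"`image_batch_size` must be either 1 or equal to the prompt batch size, which is {batch_size}."
        )

    return image.repeat_interleave(repeat_by, dim=0)


# One control image, two prompts, three images per prompt -> six conditioning images.
single = torch.randn(1, 3, 64, 64)
print(repeat_control_image(single, batch_size=2, num_images_per_prompt=3).shape)  # torch.Size([6, 3, 64, 64])

# One control image per prompt -> each image is repeated num_images_per_prompt times.
per_prompt = torch.randn(2, 3, 64, 64)
print(repeat_control_image(per_prompt, batch_size=2, num_images_per_prompt=3).shape)  # torch.Size([6, 3, 64, 64])

If the call sites had kept passing the pre-multiplied batch_size * num_images_per_prompt, a single conditioning image would be repeated an extra factor of num_images_per_prompt, so the prepare_image change and the call-site change have to land together.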