diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index c558c40be375..ea86d669f392 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -15,7 +15,7 @@ SparseControlNetModel, SparseControlNetOutput, ) - from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel + from .controlnet_union import ControlNetUnionModel from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel from .multicontrolnet import MultiControlNetModel diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 076629200eac..fc80da76235b 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...image_processor import PipelineImageInput from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import logging from ..attention_processor import ( @@ -40,76 +38,6 @@ from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module -@dataclass -class ControlNetUnionInput: - """ - The image input of [`ControlNetUnionModel`]: - - - 0: openpose - - 1: depth - - 2: hed/pidi/scribble/ted - - 3: canny/lineart/anime_lineart/mlsd - - 4: normal - - 5: segment - """ - - openpose: Optional[PipelineImageInput] = None - depth: Optional[PipelineImageInput] = None - hed: Optional[PipelineImageInput] = None - canny: Optional[PipelineImageInput] = None - normal: Optional[PipelineImageInput] = None - segment: Optional[PipelineImageInput] = None - - def __len__(self) -> int: - return len(vars(self)) - - def __iter__(self): - return iter(vars(self)) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - setattr(self, key, value) - - -@dataclass -class ControlNetUnionInputProMax: - """ - The image input of [`ControlNetUnionModel`]: - - - 0: openpose - - 1: depth - - 2: hed/pidi/scribble/ted - - 3: canny/lineart/anime_lineart/mlsd - - 4: normal - - 5: segment - - 6: tile - - 7: repaint - """ - - openpose: Optional[PipelineImageInput] = None - depth: Optional[PipelineImageInput] = None - hed: Optional[PipelineImageInput] = None - canny: Optional[PipelineImageInput] = None - normal: Optional[PipelineImageInput] = None - segment: Optional[PipelineImageInput] = None - tile: Optional[PipelineImageInput] = None - repaint: Optional[PipelineImageInput] = None - - def __len__(self) -> int: - return len(vars(self)) - - def __iter__(self): - return iter(vars(self)) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - setattr(self, key, value) - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -680,8 +608,9 @@ def forward( sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, - controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax], + controlnet_cond: List[torch.Tensor], control_type: torch.Tensor, + control_type_idx: List[int], 
conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, @@ -701,11 +630,13 @@ def forward( The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states. - controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`): + controlnet_cond (`List[torch.Tensor]`): The conditional input tensors. control_type (`torch.Tensor`): A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control type is used. + control_type_idx (`List[int]`): + The control-type index for each tensor in `controlnet_cond`, i.e. the positions of `control_type` that are set to `1`. conditioning_scale (`float`, defaults to `1.0`): The scale factor for ControlNet outputs. class_labels (`torch.Tensor`, *optional*, defaults to `None`): @@ -733,20 +664,6 @@ def forward( If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)): - raise ValueError( - "Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`" - ) - if len(controlnet_cond) != self.config.num_control_type: - if isinstance(controlnet_cond, ControlNetUnionInput): - raise ValueError( - f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`." - ) - elif isinstance(controlnet_cond, ControlNetUnionInputProMax): - raise ValueError( - f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`." - ) - # check channel order channel_order = self.config.controlnet_conditioning_channel_order @@ -830,12 +747,10 @@ def forward( inputs = [] condition_list = [] - for idx, image_type in enumerate(controlnet_cond): - if controlnet_cond[image_type] is None: - continue - condition = self.controlnet_cond_embedding(controlnet_cond[image_type]) + for cond, control_idx in zip(controlnet_cond, control_type_idx): + condition = self.controlnet_cond_embedding(cond) feat_seq = torch.mean(condition, dim=(2, 3)) - feat_seq = feat_seq + self.task_embedding[idx] + feat_seq = feat_seq + self.task_embedding[control_idx] inputs.append(feat_seq.unsqueeze(1)) condition_list.append(condition) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py index 0465391d7305..bfc28615e8b4 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py @@ -40,7 +40,6 @@ AttnProcessor2_0, XFormersAttnProcessor, ) -from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -82,7 +81,6 @@ def retrieve_latents( Examples: ```py from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL - from diffusers.models.controlnets import ControlNetUnionInputProMax from diffusers.utils import load_image import torch import numpy as np @@ -114,11 +112,8 @@ def retrieve_latents( mask_np = np.array(mask) controlnet_img_np[mask_np > 0] = 0 controlnet_img = Image.fromarray(controlnet_img_np) - union_input = ControlNetUnionInputProMax( - repaint=controlnet_img, - ) # 
generate image - image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0] + image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0] image.save("inpaint.png") ``` """ @@ -1130,7 +1125,7 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, mask_image: PipelineImageInput = None, - control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None, + control_image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, padding_mask_crop: Optional[int] = None, @@ -1158,6 +1153,7 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, guidance_rescale: float = 0.0, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), @@ -1345,20 +1341,6 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)): - raise ValueError( - "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`" - ) - if len(control_image_list) != controlnet.config.num_control_type: - if isinstance(control_image_list, ControlNetUnionInput): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`." - ) - elif isinstance(control_image_list, ControlNetUnionInputProMax): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`." - ) - # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] @@ -1375,36 +1357,44 @@ def __call__( elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] + if not isinstance(control_image, list): + control_image = [control_image] + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if len(control_image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_mode)") + + num_control_type = controlnet.config.num_control_type + + # 1. 
Check inputs - control_type = [] - for image_type in control_image_list: - if control_image_list[image_type]: - self.check_inputs( - prompt, - prompt_2, - control_image_list[image_type], - mask_image, - strength, - num_inference_steps, - callback_steps, - output_type, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - padding_mask_crop, - ) - control_type.append(1) - else: - control_type.append(0) + control_type = [0 for _ in range(num_control_type)] + for _image, control_idx in zip(control_image, control_mode): + control_type[control_idx] = 1 + self.check_inputs( + prompt, + prompt_2, + _image, + mask_image, + strength, + num_inference_steps, + callback_steps, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) control_type = torch.Tensor(control_type) @@ -1499,23 +1489,21 @@ def denoising_value_valid(dnv): init_image = init_image.to(dtype=torch.float32) # 5.2 Prepare control images - for image_type in control_image_list: - if control_image_list[image_type]: - control_image = self.prepare_control_image( - image=control_image_list[image_type], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - crops_coords=crops_coords, - resize_mode=resize_mode, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = control_image.shape[-2:] - control_image_list[image_type] = control_image + for idx, _ in enumerate(control_image): + control_image[idx] = self.prepare_control_image( + image=control_image[idx], + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image[idx].shape[-2:] # 5.3 Prepare mask mask = self.mask_processor.preprocess( @@ -1589,6 +1577,9 @@ def denoising_value_valid(dnv): original_size = original_size or (height, width) target_size = target_size or (height, width) + for _image in control_image: + if isinstance(_image, torch.Tensor): + original_size = original_size or _image.shape[-2:] # 10. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds @@ -1693,8 +1684,9 @@ def denoising_value_valid(dnv): control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=control_image_list, + controlnet_cond=control_image, control_type=control_type, + control_type_idx=control_mode, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index 58a8ba62e24e..78395243f6e4 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -43,7 +43,6 @@ AttnProcessor2_0, XFormersAttnProcessor, ) -from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -70,7 +69,6 @@ >>> # !pip install controlnet_aux >>> from controlnet_aux import LineartAnimeDetector >>> from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL - >>> from diffusers.models.controlnets import ControlNetUnionInput >>> from diffusers.utils import load_image >>> import torch @@ -89,17 +87,14 @@ ... controlnet=controlnet, ... vae=vae, ... torch_dtype=torch.float16, + ... variant="fp16", ... ) >>> pipe.enable_model_cpu_offload() >>> # prepare image >>> processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") >>> controlnet_img = processor(image, output_type="pil") - >>> # set ControlNetUnion input - >>> union_input = ControlNetUnionInput( - ... canny=controlnet_img, - ... ) >>> # generate image - >>> image = pipe(prompt, image=union_input).images[0] + >>> image = pipe(prompt, control_image=[controlnet_img], control_mode=[3], height=1024, width=1024).images[0] ``` """ @@ -791,26 +786,6 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - def check_input( - self, - image: Union[ControlNetUnionInput, ControlNetUnionInputProMax], - ): - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - if not isinstance(image, (ControlNetUnionInput, ControlNetUnionInputProMax)): - raise ValueError( - "Expected type of `image` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`" - ) - if len(image) != controlnet.config.num_control_type: - if isinstance(image, ControlNetUnionInput): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInputProMax`." - ) - elif isinstance(image, ControlNetUnionInputProMax): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInput`." 
- ) - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image def prepare_image( self, @@ -970,7 +945,7 @@ def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, - image: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None, + control_image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -997,6 +972,7 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, @@ -1018,10 +994,7 @@ def __call__( prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. - image (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`): - In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, - `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, - `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + control_image (`PipelineImageInput`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or @@ -1168,38 +1141,44 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - self.check_input(image) - # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] + if not isinstance(control_image, list): + control_image = [control_image] + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if len(control_image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_mode)") + + num_control_type = controlnet.config.num_control_type + + control_type = [0 for _ in range(num_control_type)] # 1. Check inputs. 
Raise error if not correct - control_type = [] - for image_type in image: - if image[image_type]: - self.check_inputs( - prompt, - prompt_2, - image[image_type], - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - negative_pooled_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - ) - control_type.append(1) - else: - control_type.append(0) + for _image, control_idx in zip(control_image, control_mode): + control_type[control_idx] = 1 + self.check_inputs( + prompt, + prompt_2, + _image, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) control_type = torch.Tensor(control_type) @@ -1258,20 +1238,19 @@ def __call__( ) # 4. Prepare image - for image_type in image: - if image[image_type]: - image[image_type] = self.prepare_image( - image=image[image_type], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = image[image_type].shape[-2:] + for idx, _ in enumerate(control_image): + control_image[idx] = self.prepare_image( + image=control_image[idx], + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image[idx].shape[-2:] # 5. 
Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( @@ -1312,11 +1291,11 @@ def __call__( ) # 7.2 Prepare added time ids & embeddings - for image_type in image: - if isinstance(image[image_type], torch.Tensor): - original_size = original_size or image[image_type].shape[-2:] - + original_size = original_size or (height, width) target_size = target_size or (height, width) + for _image in control_image: + if isinstance(_image, torch.Tensor): + original_size = original_size or _image.shape[-2:] add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) @@ -1424,8 +1403,9 @@ def __call__( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=image, + controlnet_cond=control_image, control_type=control_type, + control_type_idx=control_mode, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, @@ -1478,7 +1458,6 @@ def __call__( ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) - image = callback_outputs.pop("image", image) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index a3002eb565ff..f36212d70755 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -43,7 +43,6 @@ AttnProcessor2_0, XFormersAttnProcessor, ) -from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -74,7 +73,6 @@ ControlNetUnionModel, AutoencoderKL, ) - from diffusers.models.controlnets import ControlNetUnionInputProMax from diffusers.utils import load_image import torch from PIL import Image @@ -95,6 +93,7 @@ controlnet=controlnet, vae=vae, torch_dtype=torch.float16, + variant="fp16", ).to("cuda") # `enable_model_cpu_offload` is not recommended due to multiple generations height = image.height @@ -132,14 +131,12 @@ # set ControlNetUnion input result_images = [] for sub_img, crops_coords in zip(images, crops_coords_list): - union_input = ControlNetUnionInputProMax( - tile=sub_img, - ) new_width, new_height = W, H out = pipe( prompt=[prompt] * 1, image=sub_img, - control_image_list=union_input, + control_image=[sub_img], + control_mode=[6], width=new_width, height=new_height, num_inference_steps=30, @@ -1065,7 +1062,7 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, - control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None, + control_image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.8, @@ -1090,6 +1087,7 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = 
None, @@ -1119,10 +1117,7 @@ def __call__( `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The initial image will be used as the starting point for the image generation process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. - control_image_list (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`): - In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, - `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, - `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):: + control_image (`PipelineImageInput`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height @@ -1291,53 +1286,47 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)): - raise ValueError( - "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`" - ) - if len(control_image_list) != controlnet.config.num_control_type: - if isinstance(control_image_list, ControlNetUnionInput): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`." - ) - elif isinstance(control_image_list, ControlNetUnionInputProMax): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`." - ) - # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] - # 1. Check inputs. Raise error if not correct - control_type = [] - for image_type in control_image_list: - if control_image_list[image_type]: - self.check_inputs( - prompt, - prompt_2, - control_image_list[image_type], - strength, - num_inference_steps, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - ) - control_type.append(1) - else: - control_type.append(0) + if not isinstance(control_image, list): + control_image = [control_image] + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if len(control_image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_mode)") + + num_control_type = controlnet.config.num_control_type + + # 1. 
Check inputs + control_type = [0 for _ in range(num_control_type)] + for _image, control_idx in zip(control_image, control_mode): + control_type[control_idx] = 1 + self.check_inputs( + prompt, + prompt_2, + _image, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) control_type = torch.Tensor(control_type) @@ -1397,21 +1386,19 @@ def __call__( # 4. Prepare image and controlnet_conditioning_image image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - for image_type in control_image_list: - if control_image_list[image_type]: - control_image = self.prepare_control_image( - image=control_image_list[image_type], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = control_image.shape[-2:] - control_image_list[image_type] = control_image + for idx, _ in enumerate(control_image): + control_image[idx] = self.prepare_control_image( + image=control_image[idx], + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image[idx].shape[-2:] # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1444,10 +1431,11 @@ def __call__( ) # 7.2 Prepare added time ids & embeddings - for image_type in control_image_list: - if isinstance(control_image_list[image_type], torch.Tensor): - original_size = original_size or control_image_list[image_type].shape[-2:] + original_size = original_size or (height, width) target_size = target_size or (height, width) + for _image in control_image: + if isinstance(_image, torch.Tensor): + original_size = original_size or _image.shape[-2:] if negative_original_size is None: negative_original_size = original_size @@ -1531,8 +1519,9 @@ def __call__( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=control_image_list, + controlnet_cond=control_image, control_type=control_type, + control_type_idx=control_mode, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs,
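Taken together, the changes above replace the `ControlNetUnionInput`/`ControlNetUnionInputProMax` containers with plain `control_image`/`control_mode` lists. Below is a minimal usage sketch of the new API (not part of the patch): the `xinsir/controlnet-union-sdxl-1.0` checkpoint, the prompt, and the pre-computed `depth_img`/`canny_img` conditioning images are assumptions for illustration. The mode indices follow the control-type ordering documented in the removed dataclasses: 0 openpose, 1 depth, 2 hed/pidi/scribble/ted, 3 canny/lineart/anime_lineart/mlsd, 4 normal, 5 segment (ProMax checkpoints add 6 tile, 7 repaint).

```py
# Usage sketch for the new list-based API, assuming the checkpoint names below.
from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL
from diffusers.utils import load_image
import torch

controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()

# Pre-computed conditioning images (hypothetical local files).
depth_img = load_image("depth.png")
canny_img = load_image("canny.png")

# control_mode gives the control-type index for each entry of control_image;
# the pipeline turns it into the one-hot `control_type` tensor and forwards the
# indices as `control_type_idx` to select the matching task embeddings.
image = pipe(
    "a futuristic cityscape at dusk",
    control_image=[depth_img, canny_img],
    control_mode=[1, 3],
    height=1024,
    width=1024,
).images[0]
image.save("multi_control.png")
```

Passing the indices explicitly lets one call mix any subset of conditions without a fixed-size container, which is why `forward` now takes `control_type_idx` alongside the one-hot `control_type` tensor.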