From bcdbe9ea32e08a9ebe5567c0a966cea085fdbabb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Tue, 15 Jul 2025 10:22:09 -0400 Subject: [PATCH] img2img fixes --- ...pipeline_controlnet_union_sd_xl_img2img.py | 286 ++++++++++++------ 1 file changed, 194 insertions(+), 92 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index 82ef4b6391eb..65e2fe661797 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -19,7 +19,6 @@ import numpy as np import PIL.Image import torch -import torch.nn.functional as F from transformers import ( CLIPImageProcessor, CLIPTextModel, @@ -38,7 +37,13 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel +from ...models import ( + AutoencoderKL, + ControlNetUnionModel, + ImageProjection, + MultiControlNetUnionModel, + UNet2DConditionModel, +) from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, @@ -262,7 +267,9 @@ def __init__( tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: ControlNetUnionModel, + controlnet: Union[ + ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel + ], scheduler: KarrasDiffusionSchedulers, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, @@ -272,8 +279,8 @@ def __init__( ): super().__init__() - if not isinstance(controlnet, ControlNetUnionModel): - raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.") + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetUnionModel(controlnet) self.register_modules( vae=vae, @@ -649,6 +656,7 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + control_mode=None, callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: @@ -722,28 +730,44 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetUnionModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, ControlNetUnionModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetUnionModel) - ): - self.check_image(image, prompt, prompt_embeds) - else: - assert False + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, ControlNetUnionModel): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + elif isinstance(controlnet, MultiControlNetUnionModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + elif not all(isinstance(i, list) for i in image): + raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for images_ in image: + for image_ in images_: + self.check_image(image_, prompt, prompt_embeds) if not isinstance(control_guidance_start, (tuple, list)): control_guidance_start = [control_guidance_start] + if isinstance(controlnet, MultiControlNetUnionModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + if not isinstance(control_guidance_end, (tuple, list)): control_guidance_end = [control_guidance_end] @@ -762,6 +786,15 @@ def check_inputs( if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + # Check `control_mode` + if isinstance(controlnet, ControlNetUnionModel): + if max(control_mode) >= controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.") + elif isinstance(controlnet, MultiControlNetUnionModel): + for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets): + if max(_control_mode) >= _controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.") + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." @@ -1049,7 +1082,7 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, - control_image: PipelineImageInput = None, + control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.8, @@ -1074,7 +1107,7 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, - control_mode: Optional[Union[int, List[int]]] = None, + control_mode: Optional[Union[int, List[int], List[List[int]]]] = None, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, @@ -1104,13 +1137,13 @@ def __call__( `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The initial image will be used as the starting point for the image generation process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. - control_image (`PipelineImageInput`): - The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If - the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also - be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in - init, images must be passed as a list such that each element of the list can be correctly batched for - input to a single controlnet. + control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. height (`int`, *optional*, defaults to the size of control_image): The height in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) @@ -1184,16 +1217,21 @@ def __call__( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. If multiple ControlNets are specified in init, you can set the - corresponding scale as a list. + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the controlnet starts applying. + The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the controlnet stops applying. + The percentage of total steps at which the ControlNet stops applying. + control_mode (`int` or `List[int]` or `List[List[int]], *optional*): + The control condition types for the ControlNet. See the ControlNet's model card forinformation on the + available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list + where each ControlNet should have its corresponding control mode list. Should reflect the order of + conditions in control_image original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as @@ -1273,12 +1311,6 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - if not isinstance(control_image, list): control_image = [control_image] else: @@ -1287,37 +1319,56 @@ def __call__( if not isinstance(control_mode, list): control_mode = [control_mode] - if len(control_image) != len(control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") + if isinstance(controlnet, MultiControlNetUnionModel): + control_image = [[item] for item in control_image] + control_mode = [[item] for item in control_mode] - num_control_type = controlnet.config.num_control_type + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + if isinstance(controlnet_conditioning_scale, float): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult # 1. Check inputs - control_type = [0 for _ in range(num_control_type)] - for _image, control_idx in zip(control_image, control_mode): - control_type[control_idx] = 1 - self.check_inputs( - prompt, - prompt_2, - _image, - strength, - num_inference_steps, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - ) + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + control_mode, + callback_on_step_end_tensor_inputs, + ) - control_type = torch.Tensor(control_type) + if isinstance(controlnet, ControlNetUnionModel): + control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1) + for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets) + ] self._guidance_scale = guidance_scale self._clip_skip = clip_skip @@ -1334,7 +1385,11 @@ def __call__( device = self._execution_device - global_pool_conditions = controlnet.config.global_pool_conditions + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetUnionModel) + else controlnet.nets[0].config.global_pool_conditions + ) guess_mode = guess_mode or global_pool_conditions # 3.1. Encode input prompt @@ -1372,22 +1427,55 @@ def __call__( self.do_classifier_free_guidance, ) - # 4. Prepare image and controlnet_conditioning_image + # 4.1 Prepare image image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - for idx, _ in enumerate(control_image): - control_image[idx] = self.prepare_control_image( - image=control_image[idx], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = control_image[idx].shape[-2:] + # 4.2 Prepare control images + if isinstance(controlnet, ControlNetUnionModel): + control_images = [] + + for image_ in control_image: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + + elif isinstance(controlnet, MultiControlNetUnionModel): + control_images = [] + + for control_image_ in control_image: + images = [] + + for image_ in control_image_: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + control_images.append(images) + + control_image = control_images + height, width = control_image[0][0].shape[-2:] # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1414,10 +1502,11 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): - controlnet_keep.append( - 1.0 - - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end) - ) + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps) # 7.2 Prepare added time ids & embeddings original_size = original_size or (height, width) @@ -1460,12 +1549,25 @@ def __call__( prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) - control_type = ( - control_type.reshape(1, -1) - .to(device, dtype=prompt_embeds.dtype) - .repeat(batch_size * num_images_per_prompt * 2, 1) + + control_type_repeat_factor = ( + batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1) ) + if isinstance(controlnet, ControlNetUnionModel): + control_type = ( + control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + ) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + _control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + for _control_type in control_type + ] + # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: