diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index 1de7ba424d54..4e135f9391dd 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -17,14 +17,16 @@
 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
@@ -138,7 +140,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +178,14 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             Provides additional conditioning to the `unet` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor for IP Adapter.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
@@ -194,6 +202,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
@@ -223,6 +233,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -727,6 +739,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device: (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+ """ + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=self.dtype) + + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: + single_image_embeds = self.encode_image(ip_adapter_image, device) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) + else: + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") + + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + + return image_embeds.to(device=device) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -754,6 +844,8 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -843,6 +935,12 @@ def __call__( Pre-generated negative pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -1040,7 +1138,22 @@ def __call__( # SD35 official 8b controlnet does not use encoder_hidden_states controlnet_encoder_hidden_states = None - # 7. Denoising loop + # 7. Prepare image embeddings + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + + # 8. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index a53d786798ca..4ec0eb829b69 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -870,7 +870,8 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. - ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. ip_adapter_image_embeds (`torch.Tensor`, *optional*): Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 5c547164c29a..7527d17af32a 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -150,6 +150,8 @@ def get_dummy_components( "transformer": transformer, "vae": vae, "controlnet": controlnet, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0):
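Usage note (not part of the diff above): a minimal sketch of how the new `ip_adapter_image` path could be exercised once this patch is applied. The checkpoint IDs, adapter repository, and image URLs below are illustrative assumptions, not something this patch pins down; `load_ip_adapter` and `set_ip_adapter_scale` come from the `SD3IPAdapterMixin` that the pipeline now inherits.

```python
# Hypothetical usage sketch for the IP-Adapter + ControlNet path added in this patch.
# Model IDs and image URLs are placeholders, not dictated by the diff.
import torch

from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

controlnet = SD3ControlNetModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# `load_ip_adapter` / `set_ip_adapter_scale` are provided by SD3IPAdapterMixin, which the
# pipeline now inherits; loading the adapter is expected to populate the new optional
# `image_encoder` and `feature_extractor` components registered above.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")  # assumed adapter repo
pipe.set_ip_adapter_scale(0.6)

control_image = load_image("https://example.com/canny_edges.png")  # placeholder URL
ip_reference = load_image("https://example.com/style_reference.png")  # placeholder URL

image = pipe(
    prompt="a futuristic city at dusk, cinematic lighting",
    control_image=control_image,
    ip_adapter_image=ip_reference,  # routed through prepare_ip_adapter_image_embeds above
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_controlnet_ip_adapter.png")
```

Passing precomputed `ip_adapter_image_embeds` instead of `ip_adapter_image` skips `encode_image` entirely; with classifier-free guidance enabled, that tensor is expected to already contain the negative embedding, since `prepare_ip_adapter_image_embeds` splits it with `chunk(2)`.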