@@ -17,14 +17,16 @@
 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
@@ -159,7 +161,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetInpaintingPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -192,13 +196,17 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
         controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]):
-            Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+            Provides additional conditioning to the `transformer` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained vision model that produces image embeddings for the IP-Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor that prepares input images for the image encoder.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
@@ -215,6 +223,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
 
@@ -229,6 +239,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(
@@ -775,6 +787,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
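+        # Use the penultimate hidden state as the image feature; this is the layer
+        # IP-Adapter-style conditioning typically taps rather than the final pooled output.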
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
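+        # Precomputed embeddings under classifier-free guidance are expected as a single
+        # tensor laid out as torch.cat([negative, positive]), so they can be split with chunk(2).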
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
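+        # Negative embeddings are concatenated first so the batch order matches the
+        # duplicated latents used for classifier-free guidance.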
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -803,6 +893,8 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -896,6 +988,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1057,7 +1155,22 @@ def __call__(
                 ]
                 controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)
 
-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
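+            # Route the image embeddings to the attention processors through joint_attention_kwargs.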
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
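Note for reviewers: below is a minimal usage sketch of the IP-Adapter path added in this diff. The checkpoint ids, the input file names, and the `control_image`/`control_mask` values are illustrative assumptions, not part of this change; `load_ip_adapter` comes from the newly inherited `SD3IPAdapterMixin`.

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetInpaintingPipeline
from diffusers.utils import load_image

# Hypothetical checkpoint ids, chosen for illustration only.
controlnet = SD3ControlNetModel.from_pretrained(
    "alimama-creative/SD3-Controlnet-Inpainting", torch_dtype=torch.float16
)
pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# SD3IPAdapterMixin provides load_ip_adapter(); it also populates the new
# image_encoder and feature_extractor components used by encode_image().
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")

init_image = load_image("room.png")    # image to inpaint (placeholder path)
mask_image = load_image("mask.png")    # white = region to regenerate
style_image = load_image("style.png")  # reference image for the IP-Adapter

result = pipe(
    prompt="a modern armchair in a sunlit living room",
    control_image=init_image,
    control_mask=mask_image,
    ip_adapter_image=style_image,  # encoded and routed via joint_attention_kwargs
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
result.save("inpainted.png")
```

If `ip_adapter_image_embeds` is passed instead of `ip_adapter_image`, it must already contain the negative embeddings in the first half of the batch when classifier-free guidance is enabled, matching the layout expected by `prepare_ip_adapter_image_embeds` above.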