 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -162,7 +164,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -194,10 +196,14 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         tokenizer_3 (`T5TokenizerFast`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor for IP Adapter.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
@@ -211,6 +217,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
 
@@ -224,6 +232,8 @@ def __init__(
             tokenizer_3=tokenizer_3,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
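With `image_encoder` and `feature_extractor` registered as optional components, the pipeline can be assembled with a vision backbone supplied explicitly. A minimal sketch, assuming a SigLIP encoder; the checkpoint names below are illustrative assumptions, not taken from this commit:

    from transformers import SiglipImageProcessor, SiglipVisionModel
    from diffusers import StableDiffusion3InpaintPipeline

    # Hypothetical choice of vision backbone; any compatible `PreTrainedModel` plus its processor can be passed.
    image_encoder = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
    feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

    pipe = StableDiffusion3InpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-medium",
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
    )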
@@ -826,6 +836,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
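When classifier-free guidance is active, `prepare_ip_adapter_image_embeds` expects (and returns) the negative image embeddings stacked ahead of the positive ones along the batch dimension, which is why it can split precomputed embeddings with `chunk(2)`. A rough illustration of that layout; the tensor shape is made up for the example:

    import torch

    positive = torch.randn(1, 729, 1152)    # hypothetical (num_images, tokens, dim)
    negative = torch.zeros_like(positive)   # zeros act as the unconditional image embedding

    # Layout expected for `ip_adapter_image_embeds` when `do_classifier_free_guidance=True`
    ip_adapter_image_embeds = torch.cat([negative, positive], dim=0)
    neg, pos = ip_adapter_image_embeds.chunk(2)  # recovers the two halves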
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
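The warning above is directly actionable: if the chosen image encoder relies on `torch.nn.MultiheadAttention`, it can be excluded before sequential offloading is enabled. A short sketch, assuming `pipe` is an already constructed pipeline:

    pipe._exclude_from_cpu_offload.append("image_encoder")  # keep the image encoder off the offload hooks
    pipe.enable_sequential_cpu_offload()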
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -853,6 +941,8 @@ def __call__(
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         clip_skip: Optional[int] = None,
@@ -890,9 +980,9 @@ def __call__(
             mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
                 `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
                 latents tensor will be generated by `mask_image`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
             padding_mask_crop (`int`, *optional*, defaults to `None`):
                 The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
@@ -953,6 +1043,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1160,7 +1256,22 @@ def __call__(
11601256 f"The transformer { self .transformer .__class__ } should have 16 input channels or 33 input channels, not { self .transformer .config .in_channels } ."
11611257 )
11621258
1163- # 7. Denoising loop
1259+ # 7. Prepare image embeddings
1260+ if (ip_adapter_image is not None and self .is_ip_adapter_active ) or ip_adapter_image_embeds is not None :
1261+ ip_adapter_image_embeds = self .prepare_ip_adapter_image_embeds (
1262+ ip_adapter_image ,
1263+ ip_adapter_image_embeds ,
1264+ device ,
1265+ batch_size * num_images_per_prompt ,
1266+ self .do_classifier_free_guidance ,
1267+ )
1268+
1269+ if self .joint_attention_kwargs is None :
1270+ self ._joint_attention_kwargs = {"ip_adapter_image_embeds" : ip_adapter_image_embeds }
1271+ else :
1272+ self ._joint_attention_kwargs .update (ip_adapter_image_embeds = ip_adapter_image_embeds )
1273+
1274+ # 8. Denoising loop
11641275 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
11651276 self ._num_timesteps = len (timesteps )
11661277 with self .progress_bar (total = num_inference_steps ) as progress_bar :
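Taken together, the changes let the inpaint pipeline consume an IP-Adapter image at call time. A minimal end-to-end sketch, assuming an SD3.5 checkpoint and an SD3.5 IP-Adapter repo; the repo ids and placeholder URLs are assumptions, not part of this commit:

    import torch
    from diffusers import StableDiffusion3InpaintPipeline
    from diffusers.utils import load_image

    pipe = StableDiffusion3InpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
    ).to("cuda")

    # load_ip_adapter / set_ip_adapter_scale come from the newly added SD3IPAdapterMixin
    pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")  # assumed repo id
    pipe.set_ip_adapter_scale(0.6)

    image = load_image("https://example.com/input.png")      # placeholder inputs
    mask = load_image("https://example.com/mask.png")
    style_ref = load_image("https://example.com/reference.png")

    result = pipe(
        prompt="a fluffy cat sitting on a bench",
        image=image,
        mask_image=mask,
        ip_adapter_image=style_ref,
    ).images[0]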