# limitations under the License.

import inspect
- from typing import Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import (
+     BaseImageProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
+     PreTrainedModel,
    T5EncoderModel,
    T5TokenizerFast,
)

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
- from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+ from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -162,7 +164,7 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


- class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
    r"""
    Args:
        transformer ([`SD3Transformer2DModel`]):
@@ -194,10 +196,14 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
        tokenizer_3 (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+         image_encoder (`PreTrainedModel`, *optional*):
+             Pre-trained vision model that computes image features for the IP-Adapter.
+         feature_extractor (`BaseImageProcessor`, *optional*):
+             Image processor that prepares inputs for the IP-Adapter image encoder.
    """

-     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-     _optional_components = []
+     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+     _optional_components = ["image_encoder", "feature_extractor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

    def __init__(
@@ -211,6 +217,8 @@ def __init__(
        tokenizer_2: CLIPTokenizer,
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5TokenizerFast,
+         image_encoder: PreTrainedModel = None,
+         feature_extractor: BaseImageProcessor = None,
    ):
        super().__init__()

@@ -224,6 +232,8 @@ def __init__(
            tokenizer_3=tokenizer_3,
            transformer=transformer,
            scheduler=scheduler,
+             image_encoder=image_encoder,
+             feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
@@ -818,6 +828,10 @@ def clip_skip(self):
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

+     @property
+     def joint_attention_kwargs(self):
+         return self._joint_attention_kwargs
+ 
    @property
    def num_timesteps(self):
        return self._num_timesteps

@@ -826,6 +840,84 @@ def num_timesteps(self):
    def interrupt(self):
        return self._interrupt

+     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+     def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+         """Encodes the given image into a feature representation using a pre-trained image encoder.
+ 
+         Args:
+             image (`PipelineImageInput`):
+                 Input image to be encoded.
+             device (`torch.device`):
+                 Torch device.
+ 
+         Returns:
+             `torch.Tensor`: The encoded image feature representation.
+         """
+         if not isinstance(image, torch.Tensor):
+             image = self.feature_extractor(image, return_tensors="pt").pixel_values
+ 
+         image = image.to(device=device, dtype=self.dtype)
+ 
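+         # Return the penultimate hidden state rather than the pooled output; IP-Adapter
+         # image projection layers are typically trained on this intermediate representation.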
+         return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ 
+     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+     def prepare_ip_adapter_image_embeds(
+         self,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+         device: Optional[torch.device] = None,
+         num_images_per_prompt: int = 1,
+         do_classifier_free_guidance: bool = True,
+     ) -> torch.Tensor:
+         """Prepares image embeddings for use in the IP-Adapter.
+ 
+         Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+ 
+         Args:
+             ip_adapter_image (`PipelineImageInput`, *optional*):
+                 The input image to extract features from for IP-Adapter.
+             ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                 Precomputed image embeddings.
+             device (`torch.device`, *optional*):
+                 Torch device.
+             num_images_per_prompt (`int`, defaults to 1):
+                 Number of images that should be generated per prompt.
+             do_classifier_free_guidance (`bool`, defaults to True):
+                 Whether to use classifier-free guidance or not.
+         """
+         device = device or self._execution_device
+ 
+         if ip_adapter_image_embeds is not None:
+             if do_classifier_free_guidance:
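+                 # Precomputed embeddings are expected to contain the negative and positive
+                 # halves stacked along the batch dimension (negative first).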
+                 single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+             else:
+                 single_image_embeds = ip_adapter_image_embeds
+         elif ip_adapter_image is not None:
+             single_image_embeds = self.encode_image(ip_adapter_image, device)
+             if do_classifier_free_guidance:
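+                 # Use all-zero embeddings as the unconditional image branch for
+                 # classifier-free guidance.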
+                 single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+         else:
+             raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+ 
+         image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ 
+         if do_classifier_free_guidance:
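+             # Keep the [negative, positive] ordering so a single forward pass covers
+             # both guidance branches.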
+             negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+             image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+ 
+         return image_embeds.to(device=device)
+ 
+     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+     def enable_sequential_cpu_offload(self, *args, **kwargs):
+         if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+             logger.warning(
+                 "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                 "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                 "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+             )
+ 
+         super().enable_sequential_cpu_offload(*args, **kwargs)
+ 
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -853,8 +945,11 @@ def __call__(
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -890,9 +985,9 @@ def __call__(
            mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
                `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
                latents tensor will be generated by `mask_image`.
-             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+             height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+             width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            padding_mask_crop (`int`, *optional*, defaults to `None`):
                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
@@ -953,12 +1048,22 @@ def __call__(
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
+             ip_adapter_image (`PipelineImageInput`, *optional*):
+                 Optional image input to work with IP Adapters.
+             ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                 emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                 `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
                a plain tuple.
+             joint_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -1006,6 +1111,7 @@ def __call__(

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
+         self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
@@ -1160,7 +1266,22 @@ def __call__(
11601266 f"The transformer { self .transformer .__class__ } should have 16 input channels or 33 input channels, not { self .transformer .config .in_channels } ."
11611267 )
11621268
1163- # 7. Denoising loop
1269+ # 7. Prepare image embeddings
1270+ if (ip_adapter_image is not None and self .is_ip_adapter_active ) or ip_adapter_image_embeds is not None :
1271+ ip_adapter_image_embeds = self .prepare_ip_adapter_image_embeds (
1272+ ip_adapter_image ,
1273+ ip_adapter_image_embeds ,
1274+ device ,
1275+ batch_size * num_images_per_prompt ,
1276+ self .do_classifier_free_guidance ,
1277+ )
1278+
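+             # Route the prepared image embeddings to the IP-Adapter attention processors
+             # through joint_attention_kwargs.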
+             if self.joint_attention_kwargs is None:
+                 self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+             else:
+                 self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+ 
+         # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1181,6 +1302,7 @@ def __call__(
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

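
For reference, a minimal usage sketch of the IP-Adapter path added in this diff. The checkpoint id and file paths below are illustrative placeholders, not tested references; substitute a real SD3 IP-Adapter repo id.

import torch
from diffusers import StableDiffusion3InpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
# Hypothetical SD3 IP-Adapter checkpoint; replace with a real repo id.
pipe.load_ip_adapter("your-org/sd3-ip-adapter")
pipe.set_ip_adapter_scale(0.6)  # strength of the image conditioning

init_image = load_image("init.png")        # image to inpaint (placeholder path)
mask_image = load_image("mask.png")        # white = repaint, black = keep (placeholder path)
style_image = load_image("reference.png")  # reference image for the IP-Adapter (placeholder path)

result = pipe(
    prompt="a cat sitting on a park bench",
    image=init_image,
    mask_image=mask_image,
    ip_adapter_image=style_image,
).images[0]
result.save("output.png")

Passing ip_adapter_image_embeds instead of ip_adapter_image skips the image encoder; with classifier-free guidance enabled, the tensor must stack the negative and positive embeddings along the batch dimension, matching prepare_ip_adapter_image_embeds above.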