# limitations under the License.

import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import (
+    BaseImageProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
+    PreTrainedModel,
    T5EncoderModel,
    T5TokenizerFast,
)

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -162,7 +164,7 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


-class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
    r"""
    Args:
        transformer ([`SD3Transformer2DModel`]):
@@ -194,10 +196,14 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
        tokenizer_3 (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor for IP Adapter.
    """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

    def __init__(
@@ -211,6 +217,8 @@ def __init__(
        tokenizer_2: CLIPTokenizer,
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5TokenizerFast,
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
    ):
        super().__init__()

@@ -224,6 +232,8 @@ def __init__(
            tokenizer_3=tokenizer_3,
            transformer=transformer,
            scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
@@ -818,6 +828,10 @@ def clip_skip(self):
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
    @property
    def num_timesteps(self):
        return self._num_timesteps
@@ -826,6 +840,84 @@ def num_timesteps(self):
    def interrupt(self):
        return self._interrupt

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
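+        # Return the penultimate hidden state of the image encoder as the IP-Adapter image feature.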
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
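+        # Repeat the image embeddings once per generated image for this prompt.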
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
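+        # For classifier-free guidance, prepend the unconditional (zero) image embeddings so the
+        # batch layout matches the prompt embeddings (negative first, positive second).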
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -853,8 +945,11 @@ def __call__(
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -890,9 +985,9 @@ def __call__(
            mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
                `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
                latents tensor will be generated by `mask_image`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            padding_mask_crop (`int`, *optional*, defaults to `None`):
                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
@@ -953,12 +1048,22 @@ def __call__(
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
                a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -1006,6 +1111,7 @@ def __call__(

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
@@ -1160,7 +1266,22 @@ def __call__(
                f"The transformer {self.transformer.__class__} {self.transformer.config.in_channels}
            )

-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
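+            # Hand the prepared image embeddings to the IP-Adapter attention processors via joint_attention_kwargs.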
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1181,6 +1302,7 @@ def __call__(
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

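For context, a minimal usage sketch of the IP-Adapter path this diff adds. Checkpoint names and file paths below are illustrative placeholders, and `load_ip_adapter` / `set_ip_adapter_scale` are assumed to be provided by the `SD3IPAdapterMixin` added to the class above.

import torch
from diffusers import StableDiffusion3InpaintPipeline
from diffusers.utils import load_image

# Load the SD3 inpainting pipeline (placeholder checkpoint id).
pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "<sd3-checkpoint>", torch_dtype=torch.float16
).to("cuda")

# Load an SD3 IP-Adapter checkpoint and set its influence (methods assumed from SD3IPAdapterMixin).
pipe.load_ip_adapter("<ip-adapter-repo-or-path>")
pipe.set_ip_adapter_scale(0.6)

image = load_image("<path/to/image.png>")              # image to inpaint
mask = load_image("<path/to/mask.png>")                # white = region to repaint
ip_reference = load_image("<path/to/reference.png>")   # reference image for the IP-Adapter

result = pipe(
    prompt="a photo of a cat",
    image=image,
    mask_image=mask,
    # Routed through prepare_ip_adapter_image_embeds and joint_attention_kwargs (step 7 above).
    ip_adapter_image=ip_reference,
).images[0]
result.save("inpainted.png")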