 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
@@ -159,7 +161,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetInpaintingPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -192,13 +196,17 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
         controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]):
-            Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+            Provides additional conditioning to the `transformer` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained vision model that produces image features for the IP-Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor that prepares input images for the image encoder.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
@@ -215,6 +223,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
 
@@ -229,6 +239,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(
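
With `image_encoder` and `feature_extractor` registered as optional components, the pipeline can be paired with an SD3 IP-Adapter through the new `SD3IPAdapterMixin`. A minimal loading sketch follows; the checkpoint ids and `weight_name` are illustrative assumptions rather than values verified for this pipeline:

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetInpaintingPipeline

# Checkpoint ids below are assumptions for illustration only.
controlnet = SD3ControlNetModel.from_pretrained(
    "alimama-creative/SD3-Controlnet-Inpainting", torch_dtype=torch.float16
)
pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

# SD3IPAdapterMixin provides load_ip_adapter / set_ip_adapter_scale; the
# adapter repo and weight name are likewise assumptions.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter", weight_name="ip-adapter.bin")
pipe.set_ip_adapter_scale(0.6)
pipe.to("cuda")
```
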
@@ -410,9 +422,9 @@ def encode_prompt(
             negative_prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                 `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
+            negative_prompt_3 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
-                `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
@@ -775,6 +787,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
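
Note that `encode_image` returns the encoder's penultimate hidden state (`hidden_states[-2]`) rather than a pooled vector, a common choice for IP-Adapter features. A rough usage sketch; the vision tower, and therefore the output shape, depends on the loaded checkpoint, and the shape shown assumes a CLIP-ViT-H/14-style encoder at 224x224:

```python
from diffusers.utils import load_image

# `pipe` comes from the loading sketch above; any RGB image works as a reference.
reference = load_image("https://example.com/reference.png")  # placeholder URL
features = pipe.encode_image(reference, device="cuda")

# For a hypothetical CLIP-ViT-H/14 tower this is (1, 257, 1280):
# one class token plus 16x16 patches at the penultimate layer's width.
print(features.shape)
```
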
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to `True`):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
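
The classifier-free guidance layout above mirrors the duplicated latents: zeros (or the first chunk of a precomputed tensor) act as the unconditional branch, and the result stacks negatives before positives along the batch dimension. A standalone sketch with illustrative shapes:

```python
import torch

# Pretend encode_image produced a single (1, seq_len, dim) feature map.
single = torch.randn(1, 257, 1280)
negative = torch.zeros_like(single)  # unconditional branch for CFG

num_images_per_prompt = 2
positive = torch.cat([single] * num_images_per_prompt, dim=0)    # (2, 257, 1280)
negative = torch.cat([negative] * num_images_per_prompt, dim=0)  # (2, 257, 1280)

# Final layout matches the duplicated latents: negatives first, then positives.
image_embeds = torch.cat([negative, positive], dim=0)            # (4, 257, 1280)
assert image_embeds.shape[0] == 2 * num_images_per_prompt
```
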
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
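
As the warning text suggests, callers hitting the `torch.nn.MultiheadAttention` offloading issue can opt the image encoder out before enabling offload:

```python
# Keep the image encoder resident on the accelerator; offload everything else.
pipe._exclude_from_cpu_offload.append("image_encoder")
pipe.enable_sequential_cpu_offload()
```
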
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -803,6 +893,8 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -896,6 +988,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
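
Putting the new arguments together, an inpainting call with an IP-Adapter style reference might look like the sketch below. The URLs and prompt are placeholders, and `control_image`/`control_mask` are assumed to be this pipeline's existing inpainting inputs:

```python
from diffusers.utils import load_image

init_image = load_image("https://example.com/room.png")       # placeholder URL
mask_image = load_image("https://example.com/room_mask.png")  # placeholder URL
style_ref = load_image("https://example.com/style.png")       # placeholder URL

result = pipe(
    prompt="a sofa in a sunlit living room",
    control_image=init_image,
    control_mask=mask_image,
    ip_adapter_image=style_ref,  # new: IP-Adapter reference image
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
result.save("inpainted.png")
```
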
@@ -1057,7 +1155,22 @@ def __call__(
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)
 
-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt: