 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
@@ -159,7 +161,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetInpaintingPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -192,13 +196,17 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
         controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]):
-            Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+            Provides additional conditioning to the `transformer` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained vision model that produces image features for the IP-Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor that prepares input images for the IP-Adapter's image encoder.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
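
Note: a minimal loading sketch for the new optional components. The checkpoint IDs and the SigLIP encoder choice are illustrative assumptions, not mandated by this diff; `load_ip_adapter` is provided by the `SD3IPAdapterMixin` added above.

    import torch
    from transformers import SiglipImageProcessor, SiglipVisionModel
    from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetInpaintingPipeline

    # Hypothetical checkpoint IDs, for illustration only.
    image_encoder = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384", torch_dtype=torch.float16)
    feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
    controlnet = SD3ControlNetModel.from_pretrained("alimama-creative/SD3-Controlnet-Inpainting", torch_dtype=torch.float16)

    pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        controlnet=controlnet,
        image_encoder=image_encoder,          # optional component wired through __init__
        feature_extractor=feature_extractor,
        torch_dtype=torch.float16,
    )
    pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")  # hypothetical adapter checkpoint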
@@ -215,6 +223,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
+        image_encoder: Optional[PreTrainedModel] = None,
+        feature_extractor: Optional[BaseImageProcessor] = None,
     ):
         super().__init__()
 
@@ -229,6 +239,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(
@@ -775,6 +787,82 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        # The penultimate hidden state serves as the image feature representation.
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
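
Note (inline): a usage sketch of `encode_image`. The reference image path is a placeholder; the shape comment assumes a ViT-style vision tower.

    from PIL import Image

    ref = Image.open("reference.png")  # any PIL image works as PipelineImageInput
    feats = pipe.encode_image(ref, device=pipe._execution_device)
    # Penultimate hidden states of the image encoder, roughly (1, num_patches, hidden_dim).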
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for the IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            # Precomputed embeddings are expected to stack [negative, positive] when CFG is on.
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                # Zero embeddings act as the unconditional branch.
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` was provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
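
Note (inline): how the classifier-free-guidance path doubles the batch. A sketch reusing the hypothetical `pipe` and `ref` from the notes above:

    embeds = pipe.prepare_ip_adapter_image_embeds(
        ip_adapter_image=ref,
        device=pipe._execution_device,
        num_images_per_prompt=2,
        do_classifier_free_guidance=True,
    )
    # Batch layout is [negatives, positives]: zeros-like negatives are stacked first,
    # so embeds.shape[0] == 2 * num_images_per_prompt.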
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
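
Note (inline): opting the image encoder out of sequential offload, exactly as the warning above suggests (sketch):

    pipe._exclude_from_cpu_offload.append("image_encoder")
    pipe.enable_sequential_cpu_offload()  # image_encoder now stays on its device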
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -803,6 +891,8 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -896,6 +986,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP-Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for the IP-Adapter, a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It must include the negative image embedding when `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
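
Note (inline): if you precompute `ip_adapter_image_embeds`, the tensor must already contain the negative embedding when guidance is enabled, negatives first, matching the `chunk(2)` split in `prepare_ip_adapter_image_embeds`. A sketch:

    pos = pipe.encode_image(ref, device=pipe._execution_device)
    ip_embeds = torch.cat([torch.zeros_like(pos), pos], dim=0)  # [negative, positive]
    # Later: pipe(..., ip_adapter_image_embeds=ip_embeds)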
@@ -1057,7 +1153,22 @@ def __call__(
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)
 
-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            # Route the embeddings to the attention processors via joint_attention_kwargs.
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
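
Note: an end-to-end call sketch with the new arguments. The `control_image`/`control_mask` names follow this pipeline's existing signature; the input images are placeholders:

    result = pipe(
        prompt="a cat sitting on a park bench",
        control_image=control_image,   # image to inpaint
        control_mask=mask,             # white = region to repaint
        ip_adapter_image=ref,          # style/content reference for the IP-Adapter
        num_inference_steps=28,
        guidance_scale=7.0,
    ).images[0]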