 import PIL.Image
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -163,7 +165,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -197,8 +199,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
@@ -212,6 +214,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
 
@@ -225,6 +229,8 @@ def __init__(
             tokenizer_3=tokenizer_3,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
@@ -738,6 +744,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for the IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -763,6 +847,8 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
@@ -784,9 +870,9 @@ def __call__(
             prompt_3 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
                 will be used instead.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -834,6 +920,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -969,7 +1061,22 @@ def __call__(
                 generator,
             )
 
-        # 6. Denoising loop
+        # 6. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
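For context, here is a minimal usage sketch of the IP-Adapter path this diff adds to the img2img pipeline. It assumes the `load_ip_adapter` and `set_ip_adapter_scale` helpers come from the `SD3IPAdapterMixin` now added to the class bases; the checkpoint path, adapter path, image URLs, and scale value below are placeholders, not values taken from this PR.

```python
import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

# Placeholder model/adapter locations; substitute real SD3 weights and an SD3 IP-Adapter checkpoint.
pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "path/to/sd3-checkpoint", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("path/to/sd3-ip-adapter")  # assumed to be provided by SD3IPAdapterMixin
pipe.set_ip_adapter_scale(0.6)                  # example scale, tune per use case

init_image = load_image("https://example.com/init.png")        # image to transform
style_image = load_image("https://example.com/reference.png")  # IP-Adapter reference image

# `ip_adapter_image` is the new __call__ argument introduced in this diff; the pipeline
# encodes it via encode_image()/prepare_ip_adapter_image_embeds() and passes the result
# to the transformer through joint_attention_kwargs.
image = pipe(
    prompt="a cat wizard, detailed, fantasy",
    image=init_image,
    strength=0.7,
    ip_adapter_image=style_image,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("out.png")
```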