 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
@@ -138,7 +140,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +178,14 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             Provides additional conditioning to the `transformer` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained vision model used to compute image features for the IP-Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor that prepares the input image for the image encoder.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
     def __init__(
@@ -194,6 +202,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
+        image_encoder: Optional[PreTrainedModel] = None,
+        feature_extractor: Optional[BaseImageProcessor] = None,
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
199209        if  isinstance (controlnet , (list , tuple )):
@@ -223,6 +233,8 @@ def __init__(
223233            transformer = transformer ,
224234            scheduler = scheduler ,
225235            controlnet = controlnet ,
236+             image_encoder = image_encoder ,
237+             feature_extractor = feature_extractor ,
226238        )
227239        self .vae_scale_factor  =  (
228240            2  **  (len (self .vae .config .block_out_channels ) -  1 ) if  hasattr (self , "vae" ) and  self .vae  is  not None  else  8 
@@ -727,6 +739,83 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
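+        # Use the penultimate hidden state as the image feature rather than the final pooled
+        # output; these per-token features are what the IP-Adapter projection consumes.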
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for the IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to `True`):
+                Whether to use classifier-free guidance or not.
+
+        Returns:
+            `torch.Tensor`: The prepared image embeddings, moved to `device`.
+        """
+        device = device or self._execution_device
+
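+        # When precomputed embeddings are supplied with classifier-free guidance enabled, they
+        # must already contain the negative (unconditional) half stacked in front.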
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
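+        # Repeat the embeddings once per generated image, then prepend the negative batch so
+        # the layout matches the duplicated latents under classifier-free guidance.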
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -754,6 +843,8 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -843,6 +934,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to condition the generation via the IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1040,7 +1137,22 @@ def __call__(
             # SD35 official 8b controlnet does not use encoder_hidden_states
             controlnet_encoder_hidden_states = None
 
-        # 7. Denoising loop
+        # 7. Prepare image embeddings
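+        # Only build embeddings when an IP-Adapter is actually loaded (`is_ip_adapter_active`)
+        # or when the caller passed precomputed embeddings directly.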
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
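+            # Route the embeddings to the attention processors via joint_attention_kwargs,
+            # which the transformer forwards to each attention layer.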
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
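For reviewers, a minimal usage sketch of the new IP-Adapter path. The checkpoint ids and the SigLIP encoder choice are illustrative assumptions, not pinned by this diff:

import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image
from transformers import SiglipImageProcessor, SiglipVisionModel

# Assumed checkpoints; any SD3 ControlNet / IP-Adapter pair with a matching image encoder works.
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
image_encoder = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384", torch_dtype=torch.float16)
feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    image_encoder=image_encoder,  # new optional component
    feature_extractor=feature_extractor,  # new optional component
    torch_dtype=torch.float16,
).to("cuda")

# load_ip_adapter / set_ip_adapter_scale come from the newly mixed-in SD3IPAdapterMixin.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
pipe.set_ip_adapter_scale(0.6)

image = pipe(
    prompt="a photo of a cat",
    control_image=load_image("canny_edges.png"),  # ControlNet conditioning image
    ip_adapter_image=load_image("style_reference.png"),  # image prompt for the IP-Adapter
).images[0]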