 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
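The newly imported `SD3IPAdapterMixin` is what gives the pipeline its IP-Adapter entry points. A minimal sketch of the two relied on here, assuming `pipe` is an instantiated `StableDiffusion3ControlNetPipeline` and using a placeholder repo id:

```python
# Sketch only: load_ip_adapter() comes from SD3IPAdapterMixin; the repo id is
# a placeholder, not a specific checkpoint.
pipe.load_ip_adapter("<ip-adapter-repo>")
assert pipe.is_ip_adapter_active  # checked later in __call__ before preparing embeds
```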
@@ -138,7 +140,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +178,14 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             Provides additional conditioning to the `unet` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained vision model that produces image features for the IP-Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor that prepares the input image for the image encoder.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
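With the image encoder added to `model_cpu_offload_seq`, model-level offloading now moves it on and off the accelerator in sequence with the other components. A one-line sketch, assuming `pipe` is already constructed:

```python
# Components are shuttled to the GPU on demand in the order given by
# model_cpu_offload_seq, which now includes image_encoder.
pipe.enable_model_cpu_offload()
```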
@@ -194,6 +202,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
+        image_encoder: Optional[PreTrainedModel] = None,
+        feature_extractor: Optional[BaseImageProcessor] = None,
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
@@ -223,6 +233,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
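Both new components default to `None` and are listed in `_optional_components`, so checkpoints saved before this change still load. A hedged construction sketch; the SigLIP vision tower is an illustrative choice and the repo ids are placeholders:

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from transformers import SiglipImageProcessor, SiglipVisionModel

# Illustrative only: any vision model that exposes hidden_states can serve as
# the image encoder; repo ids below are placeholders.
controlnet = SD3ControlNetModel.from_pretrained("<controlnet-repo>", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    image_encoder=SiglipVisionModel.from_pretrained("<vision-encoder-repo>", torch_dtype=torch.float16),
    feature_extractor=SiglipImageProcessor(),
    torch_dtype=torch.float16,
)
```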
@@ -727,6 +739,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        # Return the penultimate hidden states of the vision encoder as the image features.
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
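Note that `encode_image` returns the penultimate hidden states (`hidden_states[-2]`) rather than a pooled output, the feature level IP-Adapter variants typically condition on. A hedged sketch of a direct call (it is normally invoked only through `prepare_ip_adapter_image_embeds`), reusing `pipe` from the construction example:

```python
from diffusers.utils import load_image

feats = pipe.encode_image(load_image("<image-url>"), device=pipe._execution_device)
print(feats.shape)  # (1, seq_len, hidden_dim): penultimate vision-encoder states
```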
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to `True`):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            # Precomputed embeddings already contain the negative half when CFG is on.
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                # Zeroed image features serve as the unconditional embedding.
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            # Negative embeddings come first, matching the text-embedding batch layout.
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
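Under classifier-free guidance the returned tensor stacks negative and positive embeddings along the batch dimension, negative first, mirroring the text-embedding layout. A self-contained sketch of that layout with illustrative shapes:

```python
import torch

single = torch.randn(1, 257, 1280)             # illustrative (batch, seq, dim)
negative = torch.zeros_like(single)            # zeroed features as the negative
embeds = torch.cat([negative, single], dim=0)  # negative first, then positive

neg, pos = embeds.chunk(2)  # consumers recover both halves with one chunk(2)
assert torch.equal(neg, negative) and torch.equal(pos, single)
```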
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
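The warning already spells out the workaround; as a snippet, assuming `pipe` from the examples above:

```python
# Keep the image encoder resident instead of hook-offloading its submodules,
# avoiding the torch.nn.MultiheadAttention failure mode noted in the warning.
pipe._exclude_from_cpu_offload.append("image_encoder")
pipe.enable_sequential_cpu_offload()
```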
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -754,6 +844,8 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -843,6 +935,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from the `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP-Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
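Taken together, the two new arguments enable IP-Adapter-conditioned ControlNet generation. A hedged end-to-end sketch; repo ids and image URLs are placeholders, and `load_ip_adapter` comes from `SD3IPAdapterMixin`:

```python
from diffusers.utils import load_image

pipe.load_ip_adapter("<ip-adapter-repo>")  # placeholder repo id
pipe.to("cuda")

image = pipe(
    prompt="a photo of an astronaut in a garden",
    control_image=load_image("<control-image-url>"),   # ControlNet conditioning
    ip_adapter_image=load_image("<style-image-url>"),  # IP-Adapter conditioning
    num_inference_steps=28,
).images[0]
```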
@@ -1040,7 +1138,22 @@ def __call__(
             # SD35 official 8b controlnet does not use encoder_hidden_states
             controlnet_encoder_hidden_states = None
 
-        # 7. Denoising loop
+        # 7. Prepare IP-Adapter image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            # Route the embeddings to the attention processors via joint_attention_kwargs.
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
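Since `prepare_ip_adapter_image_embeds` is public, the embeddings can be computed once and passed back in via `ip_adapter_image_embeds`, skipping the vision encoder on later calls. A sketch under the same assumptions as above (`style_image` and `control_image` defined earlier):

```python
embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=style_image,
    ip_adapter_image_embeds=None,
    device=pipe._execution_device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,  # result includes the negative half
)
for prompt in ["a cat wearing a hat", "a dog wearing a hat"]:
    image = pipe(prompt=prompt, control_image=control_image, ip_adapter_image_embeds=embeds).images[0]
```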