 import PIL.Image
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -163,7 +165,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -197,8 +199,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
     """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

     def __init__(
@@ -212,6 +214,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()

@@ -225,6 +229,8 @@ def __init__(
             tokenizer_3=tokenizer_3,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
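The constructor hunk above only registers the two new optional components. As a minimal sketch of how they might be supplied when loading the pipeline, assuming a SigLIP vision backbone and the SD3.5 base checkpoint (both are illustrative choices, not something this diff prescribes):

```python
import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from transformers import SiglipImageProcessor, SiglipVisionModel

# Assumed vision backbone; the IP-Adapter checkpoint you intend to load
# determines which image encoder and feature extractor are actually required.
image_encoder = SiglipVisionModel.from_pretrained(
    "google/siglip-so400m-patch14-384", torch_dtype=torch.float16
)
feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # assumed base checkpoint
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
)
```

Because both components are in `_optional_components`, omitting them keeps the pipeline usable for plain img2img without an IP-Adapter.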
@@ -738,6 +744,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
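The hunk above copies `encode_image` and `prepare_ip_adapter_image_embeds` from the text-to-image pipeline. A small self-contained sketch of the batch layout the latter produces under classifier-free guidance; the patch count and hidden size are made up and depend entirely on the image encoder:

```python
import torch

# Pretend encode_image() returned the penultimate hidden states for one
# reference image: shape (1, num_patches, hidden_dim).
single_image_embeds = torch.randn(1, 729, 1152)
single_negative_image_embeds = torch.zeros_like(single_image_embeds)  # CFG negative

num_images_per_prompt = 2
image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)

# With classifier-free guidance enabled, the unconditional (zero) embeddings are
# concatenated in front of the conditional ones along the batch dimension,
# mirroring how the prompt embeddings are arranged.
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
print(image_embeds.shape)  # torch.Size([4, 729, 1152])
```

This is also why a precomputed `ip_adapter_image_embeds` tensor passed to `__call__` must already contain the negative embedding when guidance is active: the method simply chunks it back into the two halves.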
@@ -763,6 +847,8 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
@@ -784,9 +870,9 @@ def __call__(
             prompt_3 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
                 will be used instead.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -834,6 +920,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -969,7 +1061,22 @@ def __call__(
             generator,
         )

-        # 6. Denoising loop
+        # 6. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
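Putting the pieces together, a hedged end-to-end sketch of how the img2img pipeline might be driven with an IP-Adapter once this lands. The adapter repository, the SigLIP backbone, the scale value, and the image URLs are assumptions for illustration; `load_ip_adapter` and `set_ip_adapter_scale` come from the `SD3IPAdapterMixin` added to the class bases above:

```python
import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image
from transformers import SiglipImageProcessor, SiglipVisionModel

# Same assumed SigLIP backbone as in the earlier instantiation sketch.
image_encoder = SiglipVisionModel.from_pretrained(
    "google/siglip-so400m-patch14-384", torch_dtype=torch.float16
)
feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
).to("cuda")

# Hypothetical IP-Adapter checkpoint; depending on its layout, an explicit
# weight_name=... argument may also be required.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
pipe.set_ip_adapter_scale(0.6)

init_image = load_image("https://example.com/init.png")        # image to transform
style_image = load_image("https://example.com/reference.png")  # IP-Adapter reference

result = pipe(
    prompt="a watercolor painting of a harbor at dusk",
    image=init_image,
    strength=0.7,
    ip_adapter_image=style_image,
    guidance_scale=5.0,
).images[0]
result.save("img2img_ip_adapter.png")
```

When `ip_adapter_image` is passed (and an adapter is active), step 6 above stores the prepared embeddings in `joint_attention_kwargs`, so the transformer's attention processors can pick them up without any further changes to the call site.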