Commit e8114bd

IP-Adapter for StableDiffusion3Img2ImgPipeline (#10589)
Added support for IP-Adapter to `StableDiffusion3Img2ImgPipeline`, mirroring the existing implementation in `StableDiffusion3Pipeline` (the new methods are copied from it). A usage sketch follows below.
1 parent b0c8973 commit e8114bd
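
For context, a minimal usage sketch of the feature this commit enables. The checkpoint ids, the SigLIP image encoder, the adapter repo, and the scale value are illustrative assumptions rather than part of the commit; substitute the IP-Adapter weights and image encoder you actually use.

```python
import torch
from transformers import SiglipImageProcessor, SiglipVisionModel

from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

# Assumed checkpoints for illustration only.
image_encoder = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384", torch_dtype=torch.float16)
feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    image_encoder=image_encoder,          # new optional component added by this commit
    feature_extractor=feature_extractor,  # new optional component added by this commit
    torch_dtype=torch.float16,
).to("cuda")

# load_ip_adapter / set_ip_adapter_scale come from the SD3IPAdapterMixin this commit adds to the pipeline.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")  # assumed adapter repo id
pipe.set_ip_adapter_scale(0.6)

init_image = load_image("https://example.com/init.png")       # placeholder URLs
ip_image = load_image("https://example.com/reference.png")

image = pipe(
    prompt="a watercolor painting of a cat",
    image=init_image,
    ip_adapter_image=ip_image,  # new __call__ argument added by this commit
    strength=0.7,
).images[0]
image.save("sd3_img2img_ip_adapter.png")
```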

File tree

2 files changed (+116, −7 lines)


src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 114 additions & 7 deletions
@@ -18,14 +18,16 @@
 import PIL.Image
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -163,7 +165,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -197,8 +199,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
@@ -212,6 +214,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
 
@@ -225,6 +229,8 @@ def __init__(
             tokenizer_3=tokenizer_3,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
@@ -738,6 +744,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device: (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device: (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
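
As a reading aid for the hunk above (not part of the diff): when classifier-free guidance is on, `prepare_ip_adapter_image_embeds` places the unconditional embeddings in front of the conditional ones, and a precomputed tensor passed via `ip_adapter_image_embeds` is simply split back into those two halves. A minimal sketch with assumed shapes:

```python
import torch

# Illustrative sizes only; the real shapes come from the image encoder's penultimate hidden states.
batch, num_tokens, emb_dim = 1, 64, 1152
positive = torch.randn(batch, num_tokens, emb_dim)  # stands in for encode_image(ip_adapter_image, device)
negative = torch.zeros_like(positive)               # unconditional branch uses zero embeddings

# Layout produced when do_classifier_free_guidance=True: negative half first, positive half second.
ip_adapter_image_embeds = torch.cat([negative, positive], dim=0)

# A tensor with this layout can be passed to the pipeline directly; it is split back with .chunk(2).
single_negative, single_positive = ip_adapter_image_embeds.chunk(2)
assert torch.equal(single_positive, positive) and torch.equal(single_negative, negative)
```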
@@ -763,6 +847,8 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
@@ -784,9 +870,9 @@ def __call__(
             prompt_3 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
                 will be used instead
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -834,6 +920,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -969,7 +1061,22 @@ def __call__(
                 generator,
             )
 
-        # 6. Denoising loop
+        # 6. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
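
One practical consequence of the hunk above (illustrative, not part of the commit): because `__call__` also accepts `ip_adapter_image_embeds` directly, the reference image can be encoded once and the embeddings reused across several img2img calls. A hypothetical sketch, assuming `pipe`, `init_image`, and `ip_image` are set up as in the earlier example:

```python
# Precompute the IP-Adapter embeddings once with the method added in this commit, then reuse them.
embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=ip_image,
    ip_adapter_image_embeds=None,
    device=pipe._execution_device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

for strength in (0.4, 0.6, 0.8):
    out = pipe(
        prompt="a watercolor painting of a cat",
        image=init_image,
        ip_adapter_image_embeds=embeds,  # skips re-encoding the reference image on every call
        strength=strength,
    ).images[0]
    out.save(f"sd3_img2img_ip_strength_{strength}.png")
```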

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py

Lines changed: 2 additions & 0 deletions
@@ -105,6 +105,8 @@ def get_dummy_components(self):
             "tokenizer_3": tokenizer_3,
             "transformer": transformer,
             "vae": vae,
+            "image_encoder": None,
+            "feature_extractor": None,
         }
 
     def get_dummy_inputs(self, device, seed=0):
