
Commit 64f6991

Added support for IP-Adapter
1 parent 9f06a0d commit 64f6991


2 files changed: +120 −7 lines changed


src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

Lines changed: 118 additions & 7 deletions
@@ -17,15 +17,17 @@
 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -162,7 +164,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -194,10 +196,14 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
         tokenizer_3 (`T5TokenizerFast`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor for IP Adapter.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
@@ -211,6 +217,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
 
@@ -224,6 +232,8 @@ def __init__(
             tokenizer_3=tokenizer_3,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
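
Because `image_encoder` and `feature_extractor` are registered as optional components, they default to `None` and only need to be supplied when IP-Adapter conditioning is wanted. Below is a minimal sketch (not part of this commit) of how they might be wired in at load time; the SigLIP checkpoint name and the SD3 base repository are illustrative assumptions.

import torch
from transformers import SiglipImageProcessor, SiglipVisionModel
from diffusers import StableDiffusion3InpaintPipeline

# Assumed vision backbone for the IP-Adapter; any compatible encoder/processor pair would do.
image_encoder = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

# Component kwargs passed to from_pretrained are forwarded to __init__ and registered on the
# pipeline; when omitted, both new components simply stay None.
pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",  # assumed base checkpoint
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    torch_dtype=torch.bfloat16,
)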
@@ -826,6 +836,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device: (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device: (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
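
The embedding layout produced by `prepare_ip_adapter_image_embeds` mirrors the text-embedding convention: with classifier-free guidance enabled, a zero "negative" embedding is stacked in front of the positive one along the batch dimension after both are repeated `num_images_per_prompt` times. A small self-contained sketch of that layout (the encoder output shape below is hypothetical, not taken from the diff):

import torch

single_image_embeds = torch.randn(1, 729, 1152)  # hypothetical encode_image output: (1, seq_len, hidden_dim)
single_negative_image_embeds = torch.zeros_like(single_image_embeds)  # CFG negative is all zeros

num_images_per_prompt = 2
image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)

# Negatives come first, then positives, matching the ordering used for the prompt embeddings.
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
print(image_embeds.shape)  # torch.Size([4, 729, 1152])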
@@ -853,6 +941,8 @@ def __call__(
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         clip_skip: Optional[int] = None,
@@ -890,9 +980,9 @@ def __call__(
             mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
                 `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
                 latents tensor will be generated by `mask_image`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
             padding_mask_crop (`int`, *optional*, defaults to `None`):
                 The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
@@ -953,6 +1043,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1160,7 +1256,22 @@ def __call__(
                 f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}."
             )
 
-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
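
End to end, the new `ip_adapter_image` argument only takes effect once an SD3 IP-Adapter has been loaded through the `SD3IPAdapterMixin` added to the class bases. A hedged usage sketch follows; the base checkpoint, adapter repository, image URLs, and scale value are assumptions for illustration (and `load_ip_adapter` may need an explicit weight file name for a given adapter repo). It also assumes the pipeline carries an image encoder and feature extractor, e.g. as in the earlier component sketch or provided by the adapter repository.

import torch
from diffusers import StableDiffusion3InpaintPipeline
from diffusers.utils import load_image

# Assumed base checkpoint; in practice, pass image_encoder/feature_extractor here if the repo lacks them.
pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")

# Provided by SD3IPAdapterMixin; the adapter repository name is an assumption.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
pipe.set_ip_adapter_scale(0.6)

image = load_image("https://example.com/input.png")         # image to inpaint (placeholder URL)
mask_image = load_image("https://example.com/mask.png")      # white pixels mark the region to regenerate
ip_image = load_image("https://example.com/reference.png")   # style/content reference for the IP-Adapter

result = pipe(
    prompt="a corgi sitting on a park bench",
    image=image,
    mask_image=mask_image,
    ip_adapter_image=ip_image,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
result.save("sd3_inpaint_ip_adapter.png")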

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py

Lines changed: 2 additions & 0 deletions
@@ -106,6 +106,8 @@ def get_dummy_components(self):
             "tokenizer_3": tokenizer_3,
             "transformer": transformer,
             "vae": vae,
+            "image_encoder": None,
+            "feature_extractor": None,
         }
 
     def get_dummy_inputs(self, device, seed=0):
