
Commit ef1f486

Added support for IP-Adapter
1 parent 9f06a0d commit ef1f486

File tree: 2 files changed (+119, -6 lines)

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py

Lines changed: 117 additions & 6 deletions
@@ -17,14 +17,16 @@

 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
@@ -159,7 +161,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetInpaintingPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -192,13 +196,17 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
         controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]):
-            Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+            Provides additional conditioning to the `transformer` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor for IP Adapter.
     """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

     def __init__(
@@ -215,6 +223,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()

@@ -229,6 +239,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(
@@ -775,6 +787,82 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device: (`torch.device`):
+                Torch device.
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device: (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -803,6 +891,8 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -896,6 +986,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
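
The `ip_adapter_image_embeds` contract documented above is easy to get wrong, so here is a minimal sketch of the expected layout under classifier-free guidance (illustrative shapes only, not part of this commit); `prepare_ip_adapter_image_embeds` splits a precomputed tensor with `.chunk(2)` along the batch dimension, negative half first:

```python
import torch

# Shapes are placeholders: (batch_size, num_images, emb_dim) per the docstring above.
# The real emb_dim / token count depend on the image encoder and IP-Adapter checkpoint.
positive_image_embeds = torch.randn(1, 1, 1536)
negative_image_embeds = torch.zeros_like(positive_image_embeds)

# With do_classifier_free_guidance=True the pipeline expects the negative embedding
# stacked *before* the positive one along dim 0 ...
ip_adapter_image_embeds = torch.cat([negative_image_embeds, positive_image_embeds], dim=0)

# ... because prepare_ip_adapter_image_embeds recovers the two halves with .chunk(2).
single_negative, single_positive = ip_adapter_image_embeds.chunk(2)
assert torch.equal(single_positive, positive_image_embeds)
```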
@@ -1057,7 +1153,22 @@ def __call__(
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)

-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
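
Taken together, these changes wire IP-Adapter conditioning into the SD3 ControlNet inpainting pipeline. A minimal usage sketch, not part of this commit: the repo ids, adapter weight layout, adapter scale, and input images below are placeholders or assumptions, and loading goes through the `SD3IPAdapterMixin.load_ip_adapter` API added as a base class above:

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetInpaintingPipeline
from diffusers.utils import load_image

# Placeholder repo id -- substitute a real SD3 inpainting ControlNet checkpoint.
controlnet = SD3ControlNetModel.from_pretrained("<sd3-controlnet-inpainting-repo>", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# SD3IPAdapterMixin.load_ip_adapter() loads the adapter weights (and, typically, an image
# encoder / feature extractor from the same repo if the pipeline has none); the exact
# arguments depend on the checkpoint layout.
pipe.load_ip_adapter("<sd3-ip-adapter-repo>")
pipe.set_ip_adapter_scale(0.6)  # assumed scale, tune per adapter

init_image = load_image("<init-image-url>")
mask_image = load_image("<mask-image-url>")
ip_reference = load_image("<style-reference-image-url>")

result = pipe(
    prompt="a photo of a corgi sitting on a park bench",
    control_image=init_image,
    control_mask=mask_image,
    ip_adapter_image=ip_reference,  # new argument added by this commit
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
result.save("output.png")
```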

tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py

Lines changed: 2 additions & 0 deletions
@@ -137,6 +137,8 @@ def get_dummy_components(self):
             "transformer": transformer,
             "vae": vae,
             "controlnet": controlnet,
+            "image_encoder": None,
+            "feature_extractor": None,
         }

     def get_dummy_inputs(self, device, seed=0):
