Commit 3e8b632

antoine-scenario, patrickvonplaten, and sayakpaul authored
Add IP-Adapter to StableDiffusionXLControlNetImg2ImgPipeline (#6293)
* add IP-Adapter to StableDiffusionXLControlNetImg2ImgPipeline

  Update src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

  Co-authored-by: YiYi Xu <[email protected]>

  fix tests

* fix failing test

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent dd4459a commit 3e8b632
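
For context, a minimal usage sketch of what this commit enables. It is not part of the commit itself; the checkpoint IDs, adapter weight name, and file paths below are illustrative assumptions.

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
from diffusers.utils import load_image

# Checkpoint and adapter IDs are assumptions for illustration.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# New in this commit: the pipeline mixes in IPAdapterMixin, so IP-Adapter
# weights can be loaded and an image prompt passed alongside the text prompt.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

init_image = load_image("init.png")      # placeholder: image to transform
control_image = load_image("depth.png")  # placeholder: ControlNet conditioning image
style_image = load_image("style.png")    # placeholder: IP-Adapter image prompt

result = pipe(
    prompt="a house in the woods, cinematic lighting",
    image=init_image,
    control_image=control_image,
    ip_adapter_image=style_image,  # new argument introduced by this commit
    strength=0.8,
).images[0]
result.save("out.png")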

File tree

2 files changed: +72 -7 lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 70 additions & 7 deletions
@@ -20,13 +20,23 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
 from diffusers.utils.import_utils import is_invisible_watermark_available
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...loaders import (
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,

@@ -147,7 +157,7 @@ def retrieve_latents(
 
 
 class StableDiffusionXLControlNetImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.

@@ -159,6 +169,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
     - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
     - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
     - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+    - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):

@@ -197,10 +208,19 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
             watermark output images. If not defined, it will default to True if the package is installed, otherwise no
             watermarker will be used.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(

@@ -216,6 +236,8 @@ def __init__(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
 

@@ -231,6 +253,8 @@ def __init__(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)

@@ -515,6 +539,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
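
For reference, a minimal sketch of what the `output_hidden_states=False` path of the new `encode_image` helper computes, using the underlying transformers models directly. The CLIP checkpoint ID is an assumption; this commit does not pin one.

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Assumption: any CLIP vision checkpoint with a projection head works here.
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = CLIPImageProcessor()

image = Image.new("RGB", (512, 512))  # stand-in for a real image prompt
pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

with torch.no_grad():
    # pooled, projected embedding per image (the output_hidden_states=False branch)
    image_embeds = image_encoder(pixel_values).image_embeds
    # the unconditional branch for classifier-free guidance is all zeros
    uncond_image_embeds = torch.zeros_like(image_embeds)

print(image_embeds.shape)  # torch.Size([1, 768]) for CLIP ViT-L/14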
@@ -1011,6 +1060,7 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,

@@ -1109,6 +1159,7 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.

@@ -1262,7 +1313,7 @@ def __call__(
         )
         guess_mode = guess_mode or global_pool_conditions
 
-        # 3. Encode input prompt
+        # 3.1. Encode input prompt
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )

@@ -1287,6 +1338,15 @@ def __call__(
             clip_skip=self.clip_skip,
         )
 
+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image and controlnet_conditioning_image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
 

@@ -1449,6 +1509,9 @@ def __call__(
                     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
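
A note on the `ImageProjection` check in step 3.2 above: plain IP-Adapter checkpoints install an `ImageProjection` as `unet.encoder_hid_proj` and consume the pooled `image_embeds`, while IP-Adapter Plus variants consume the penultimate hidden states instead, hence the `output_hidden_state` toggle. A sketch of the same dispatch from user code, assuming a hypothetical `pipe` with an IP-Adapter already loaded and an `ip_adapter_image` defined elsewhere:

import torch
from diffusers.models import ImageProjection

# Hypothetical: `pipe` and `ip_adapter_image` are defined elsewhere.
use_pooled = isinstance(pipe.unet.encoder_hid_proj, ImageProjection)
image_embeds, negative_image_embeds = pipe.encode_image(
    ip_adapter_image, "cuda", num_images_per_prompt=1, output_hidden_states=not use_pooled
)
# Under classifier-free guidance the negative (zero-image) embeds are
# concatenated first, matching the ordering of the text prompt embeddings.
image_embeds = torch.cat([negative_image_embeds, image_embeds])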

tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py

Lines changed: 2 additions & 0 deletions
@@ -136,6 +136,8 @@ def get_dummy_components(self, skip_first_text_encoder=False):
             "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "image_encoder": None,
+            "feature_extractor": None,
         }
         return components
 
