@@ -20,13 +20,23 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)

 from diffusers.utils.import_utils import is_invisible_watermark_available

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...loaders import (
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -147,7 +157,7 @@ def retrieve_latents(


 class StableDiffusionXLControlNetImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -159,6 +169,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

     Args:
         vae ([`AutoencoderKL`]):
@@ -197,10 +208,19 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
             watermark output images. If not defined, it will default to True if the package is installed, otherwise no
             watermarker will be used.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
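Reviewer note: the offload sequence above now routes `image_encoder` between the text encoders and the UNet, and both new components are optional. A quick self-contained check of the class-level attributes, assuming a diffusers install that includes this change:

```python
from diffusers import StableDiffusionXLControlNetImg2ImgPipeline

# Inspect the class attributes introduced above; no model weights are needed.
print(StableDiffusionXLControlNetImg2ImgPipeline.model_cpu_offload_seq)
# -> text_encoder->text_encoder_2->image_encoder->unet->vae
assert "image_encoder" in StableDiffusionXLControlNetImg2ImgPipeline._optional_components
assert "feature_extractor" in StableDiffusionXLControlNetImg2ImgPipeline._optional_components
```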
@@ -216,6 +236,8 @@ def __init__(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()

@@ -231,6 +253,8 @@ def __init__(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -515,6 +539,31 @@ def encode_prompt(

         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
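The `encode_image` method is copied verbatim from `StableDiffusionPipeline`. As a minimal sketch of its batching behavior on the projection-layer path (the toy encoder and the 1280-dim shape are illustrative assumptions, not the real CLIP vision tower):

```python
import torch

def toy_image_encoder(pixels: torch.Tensor) -> torch.Tensor:
    # Stand-in for CLIPVisionModelWithProjection: (B, 3, H, W) -> (B, 1280).
    return pixels.flatten(1).mean(dim=1, keepdim=True).expand(-1, 1280)

def encode_image_sketch(pixels: torch.Tensor, num_images_per_prompt: int):
    # Conditional embeds, repeated once per requested image per prompt.
    image_embeds = toy_image_encoder(pixels).repeat_interleave(num_images_per_prompt, dim=0)
    # The unconditional branch is all zeros with an identical shape, ready to
    # be concatenated ahead of the conditional embeds for CFG.
    return image_embeds, torch.zeros_like(image_embeds)

cond, uncond = encode_image_sketch(torch.rand(1, 3, 224, 224), num_images_per_prompt=2)
print(cond.shape, uncond.shape)  # torch.Size([2, 1280]) torch.Size([2, 1280])
```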
@@ -1011,6 +1060,7 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1109,6 +1159,7 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1262,7 +1313,7 @@ def __call__(
         )
         guess_mode = guess_mode or global_pool_conditions

-        # 3. Encode input prompt
+        # 3.1. Encode input prompt
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
@@ -1287,6 +1338,15 @@ def __call__(
             clip_skip=self.clip_skip,
         )

+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image and controlnet_conditioning_image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

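The concatenation order in step 3.2 puts the negative (zeroed) image embeds first, matching the [uncond, cond] batch layout the pipeline already uses for prompt embeds under classifier-free guidance. A toy confirmation (the 1280-dim shape is assumed for illustration):

```python
import torch

image_embeds = torch.ones(2, 1280)  # conditional embeds
negative_image_embeds = torch.zeros_like(image_embeds)

# [uncond, cond]: the first half of the UNet batch sees zeroed image conditioning.
image_embeds = torch.cat([negative_image_embeds, image_embeds])
assert image_embeds[:2].abs().sum() == 0 and (image_embeds[2:] == 1).all()
```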
@@ -1449,6 +1509,9 @@ def __call__(
                     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
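Taken together, the new path can be exercised end to end. A hedged usage sketch: the model IDs, weight name, and image URLs below are illustrative placeholders, not values pinned by this diff, and `load_ip_adapter` comes from the newly mixed-in `IPAdapterMixin`:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection

# Illustrative checkpoints; any SDXL base + SDXL ControlNet + SDXL IP-Adapter should work.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16
)
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

init_image = load_image("https://example.com/init.png")    # placeholder URL
canny_image = load_image("https://example.com/canny.png")  # placeholder URL
style_image = load_image("https://example.com/style.png")  # placeholder URL

result = pipe(
    prompt="a photo of a living room",
    image=init_image,
    control_image=canny_image,
    ip_adapter_image=style_image,  # the argument added by this diff
    strength=0.8,
).images[0]
```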