 import numpy as np
 import PIL.Image
 import torch
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -169,7 +180,11 @@ def retrieve_timesteps(
 
 
 class StableDiffusionXLAdapterPipeline(
-    DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+    DiffusionPipeline,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -183,6 +198,7 @@ class StableDiffusionXLAdapterPipeline(
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
@@ -211,8 +227,15 @@ class StableDiffusionXLAdapterPipeline(
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
 
     def __init__(
         self,
@@ -225,6 +248,8 @@ def __init__(
         adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
 
@@ -237,6 +262,8 @@ def __init__(
             unet=unet,
             adapter=adapter,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -511,6 +538,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -768,7 +820,7 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -785,6 +837,7 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -876,6 +929,7 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -991,7 +1045,7 @@ def __call__(
 
         device = self._execution_device
 
-        # 3. Encode input prompt
+        # 3.1 Encode input prompt
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -1012,6 +1066,15 @@ def __call__(
             clip_skip=clip_skip,
         )
 
+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
@@ -1028,10 +1091,10 @@ def __call__(
             latents,
         )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 6.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 6.5 Optionally get Guidance Scale Embedding
+        # 6.2 Optionally get Guidance Scale Embedding
         timestep_cond = None
         if self.unet.config.time_cond_proj_dim is not None:
             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -1090,8 +1153,7 @@ def __call__(
 
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
-        # 7.1 Apply denoising_end
+        # Apply denoising_end
         if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
             discrete_timestep_cutoff = int(
                 round(
@@ -1109,9 +1171,12 @@ def __call__(
 
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
+                # predict the noise residual
                 if i < int(num_inference_steps * adapter_conditioning_factor):
                     down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                 else:
@@ -1123,9 +1188,9 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
-                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )[0]
 
                 # perform guidance
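For context, here is a minimal usage sketch of the pipeline after these changes, combining a T2I-Adapter conditioning image with the new ip_adapter_image argument. The checkpoint IDs, IP-Adapter weight filename, and image paths below are illustrative assumptions, not taken from this diff.

# Sketch only: repo IDs, weight filenames, and image paths are assumptions for illustration.
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

# Assumed SDXL base and canny T2I-Adapter checkpoints.
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter comes from the IPAdapterMixin added to the class bases above;
# the repo, subfolder, and weight name here are assumptions.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

canny_image = load_image("canny_edges.png")      # adapter conditioning image (placeholder path)
style_image = load_image("style_reference.png")  # IP-Adapter reference image (placeholder path)

image = pipe(
    prompt="a photo of a house in the mountains",
    image=canny_image,                 # consumed by the T2I-Adapter
    ip_adapter_image=style_image,      # new argument introduced in this diff
    num_inference_steps=30,
).images[0]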