
Commit da843b3

.load_ip_adapter in StableDiffusionXLAdapterPipeline (#6246)
* Added testing notebook and .load_ip_adapter to XLAdapterPipeline
* Added annotations
* deleted testing notebook
* Update src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
  Co-authored-by: YiYi Xu <[email protected]>
* code clean up
* Add feature_extractor and image_encoder to components

---------

Co-authored-by: YiYi Xu <[email protected]>
1 parent 17cece0 commit da843b3
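
In practice, this change lets the SDXL T2I-Adapter pipeline combine structural control from a T2I-Adapter with an IP-Adapter image prompt. A minimal usage sketch follows; the checkpoint names and image paths are illustrative assumptions, not part of this commit:

```python
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

# Illustrative checkpoints; any SDXL base model plus an SDXL T2I-Adapter should work.
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# The loader this commit adds to the pipeline (via IPAdapterMixin).
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

control_image = load_image("canny_edges.png")    # structural conditioning (placeholder path)
style_image = load_image("style_reference.png")  # IP-Adapter image prompt (placeholder path)

image = pipe(
    prompt="a house in the woods",
    image=control_image,
    ip_adapter_image=style_image,  # new __call__ argument introduced by this commit
    num_inference_steps=30,
).images[0]
```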

2 files changed: +84, -17 lines

src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py

Lines changed: 80 additions & 15 deletions
@@ -18,11 +18,22 @@
 import numpy as np
 import PIL.Image
 import torch
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)

-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -169,7 +180,11 @@ def retrieve_timesteps(


 class StableDiffusionXLAdapterPipeline(
-    DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+    DiffusionPipeline,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -183,6 +198,7 @@ class StableDiffusionXLAdapterPipeline(
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

     Args:
         adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
@@ -211,8 +227,15 @@ class StableDiffusionXLAdapterPipeline(
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]

     def __init__(
         self,
@@ -225,6 +248,8 @@ def __init__(
         adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()

@@ -237,6 +262,8 @@ def __init__(
             unet=unet,
             adapter=adapter,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -511,6 +538,31 @@ def encode_prompt(

         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
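
The new `encode_image` helper, copied from `StableDiffusionPipeline`, handles the unconditional input differently per branch: the pooled branch zeroes out the embeddings with `torch.zeros_like`, while the hidden-states branch encodes an actual all-zeros image. Both branches duplicate the result per requested output via `repeat_interleave`. A standalone sketch of that batching pattern; the 1280-dim embedding size is just an example:

```python
import torch

num_images_per_prompt = 2
image_embeds = torch.randn(1, 1280)  # one encoded reference image (example dim)

# Duplicate along the batch dimension so the image embeds line up with the
# duplicated prompt embeds.
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
print(image_embeds.shape)  # torch.Size([2, 1280])

# Pooled-branch negative: a zero embedding per sample.
uncond_image_embeds = torch.zeros_like(image_embeds)
```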
@@ -768,7 +820,7 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -785,6 +837,7 @@
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -876,6 +929,7 @@
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -991,7 +1045,7 @@

         device = self._execution_device

-        # 3. Encode input prompt
+        # 3.1 Encode input prompt
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -1012,6 +1066,15 @@
             clip_skip=clip_skip,
         )

+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

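The `encoder_hid_proj` check selects the right conditioning signal for the loaded adapter: a plain `ImageProjection` head (standard IP-Adapter) consumes pooled `image_embeds`, while anything else (e.g. an IP-Adapter Plus resampler) consumes penultimate hidden states. Under classifier-free guidance the negative embeds are concatenated first, matching the `[uncond, cond]` layout of the prompt embeddings. A sketch of the same selection logic outside the pipeline, assuming a `pipe` with an IP adapter already loaded and an `ip_adapter_image` in hand:

```python
from diffusers.models import ImageProjection

# True for hidden-state projections (IP-Adapter Plus), False for the plain head.
use_hidden_states = not isinstance(pipe.unet.encoder_hid_proj, ImageProjection)
image_embeds, negative_image_embeds = pipe.encode_image(
    ip_adapter_image, pipe.device, num_images_per_prompt=1, output_hidden_states=use_hidden_states
)
```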
@@ -1028,10 +1091,10 @@
             latents,
         )

-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 6.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 6.5 Optionally get Guidance Scale Embedding
+        # 6.2 Optionally get Guidance Scale Embedding
         timestep_cond = None
         if self.unet.config.time_cond_proj_dim is not None:
             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -1090,8 +1153,7 @@

         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
-        # 7.1 Apply denoising_end
+        # Apply denoising_end
         if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
             discrete_timestep_cutoff = int(
                 round(
@@ -1109,9 +1171,12 @@

                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
+                # predict the noise residual
                 if i < int(num_inference_steps * adapter_conditioning_factor):
                     down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                 else:
@@ -1123,9 +1188,9 @@
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
-                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )[0]

                 # perform guidance

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py

Lines changed: 4 additions & 2 deletions
@@ -159,7 +159,8 @@ def get_dummy_components(self, adapter_type="full_adapter_xl", time_cond_proj_di
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
             # "safety_checker": None,
-            # "feature_extractor": None,
+            "feature_extractor": None,
+            "image_encoder": None,
         }
         return components

@@ -265,7 +266,8 @@ def get_dummy_components_with_full_downscaling(self, adapter_type="full_adapter_
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
             # "safety_checker": None,
-            # "feature_extractor": None,
+            "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
