Skip to content

Commit ca61287

Browse files
Fix IP Adapter Support for SAG Pipeline (#7260)
* Fix IP Adapter support.
* Update SAG pipeline tests; adjust the SAG pipeline to pass the tests.
---------
Co-authored-by: YiYi Xu <[email protected]>
1 parent f0c8156 commit ca61287

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed

src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,40 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
401401

402402
return image_embeds, uncond_image_embeds
403403

404+
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build the list of per-IP-Adapter image embeddings passed to the UNet.

    Args:
        ip_adapter_image: A single image or a list of images, one per loaded
            IP Adapter. Only used when `ip_adapter_image_embeds` is `None`.
        ip_adapter_image_embeds: Optional pre-computed embeddings. When given,
            they are returned unchanged and `ip_adapter_image` is ignored.
        device: Device forwarded to `self.encode_image` and used for the
            returned tensors.
        num_images_per_prompt: Number of copies of each embedding that are
            stacked along a new leading dimension.
        do_classifier_free_guidance: When true, each returned tensor is the
            concatenation `[negative, positive]` along dim 0, so a caller can
            split the halves again with `chunk(2)`.

    Returns:
        A list with one entry per IP Adapter (or `ip_adapter_image_embeds`
        itself when it was provided).

    Raises:
        ValueError: If neither input is provided, or if the number of images
            differs from the number of image projection layers on the UNet.
    """
    # Pre-computed embeddings take precedence and are passed through untouched.
    if ip_adapter_image_embeds is not None:
        return ip_adapter_image_embeds

    # Fail early with a clear message instead of wrapping `None` in a list and
    # crashing later inside the length check or the image encoder.
    if ip_adapter_image is None:
        raise ValueError("Either `ip_adapter_image` or `ip_adapter_image_embeds` must be provided.")

    if not isinstance(ip_adapter_image, list):
        ip_adapter_image = [ip_adapter_image]

    image_projection_layers = self.unet.encoder_hid_proj.image_projection_layers
    if len(ip_adapter_image) != len(image_projection_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(image_projection_layers)} IP Adapters."
        )

    image_embeds = []
    for single_ip_adapter_image, image_proj_layer in zip(ip_adapter_image, image_projection_layers):
        # Any projection layer that is not a plain `ImageProjection` is fed
        # the encoder's hidden states instead of the pooled image embeds.
        output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
        single_image_embeds, single_negative_image_embeds = self.encode_image(
            single_ip_adapter_image, device, 1, output_hidden_state
        )
        single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            # The negative copy is only needed (and only built) under CFG.
            single_negative_image_embeds = torch.stack(
                [single_negative_image_embeds] * num_images_per_prompt, dim=0
            )
            # Negative half first, matching the `chunk(2)` unpacking order.
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])

        # Place the result on `device` on both the CFG and the non-CFG path.
        single_image_embeds = single_image_embeds.to(device)
        image_embeds.append(single_image_embeds)

    return image_embeds
437+
404438
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
405439
def run_safety_checker(self, image, device, dtype):
406440
if self.safety_checker is None:
@@ -535,6 +569,7 @@ def __call__(
535569
prompt_embeds: Optional[torch.FloatTensor] = None,
536570
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
537571
ip_adapter_image: Optional[PipelineImageInput] = None,
572+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
538573
output_type: Optional[str] = "pil",
539574
return_dict: bool = True,
540575
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -583,6 +618,9 @@ def __call__(
583618
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
584619
ip_adapter_image: (`PipelineImageInput`, *optional*):
585620
Optional image input to work with IP Adapters.
621+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
622+
Pre-generated image embeddings for IP-Adapter. If not
623+
provided, embeddings are computed from the `ip_adapter_image` input argument.
586624
output_type (`str`, *optional*, defaults to `"pil"`):
587625
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
588626
return_dict (`bool`, *optional*, defaults to `True`):
@@ -636,13 +674,24 @@ def __call__(
636674
# `sag_scale = 0` means no self-attention guidance
637675
do_self_attention_guidance = sag_scale > 0.0
638676

639-
if ip_adapter_image is not None:
640-
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
641-
image_embeds, negative_image_embeds = self.encode_image(
642-
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
677+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
678+
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
679+
ip_adapter_image,
680+
ip_adapter_image_embeds,
681+
device,
682+
batch_size * num_images_per_prompt,
683+
do_classifier_free_guidance,
643684
)
685+
644686
if do_classifier_free_guidance:
645-
image_embeds = torch.cat([negative_image_embeds, image_embeds])
687+
image_embeds = []
688+
negative_image_embeds = []
689+
for tmp_image_embeds in ip_adapter_image_embeds:
690+
single_negative_image_embeds, single_image_embeds = tmp_image_embeds.chunk(2)
691+
image_embeds.append(single_image_embeds)
692+
negative_image_embeds.append(single_negative_image_embeds)
693+
else:
694+
image_embeds = ip_adapter_image_embeds
646695

647696
# 3. Encode input prompt
648697
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
@@ -687,8 +736,18 @@ def __call__(
687736
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
688737

689738
# 6.1 Add image embeds for IP-Adapter
690-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
691-
added_uncond_kwargs = {"image_embeds": negative_image_embeds} if ip_adapter_image is not None else None
739+
added_cond_kwargs = (
740+
{"image_embeds": image_embeds}
741+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
742+
else None
743+
)
744+
745+
if do_classifier_free_guidance:
746+
added_uncond_kwargs = (
747+
{"image_embeds": negative_image_embeds}
748+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
749+
else None
750+
)
692751

693752
# 7. Denoising loop
694753
store_processor = CrossAttnStoreProcessor()

tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@
3232
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
3333

3434
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
35-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
35+
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
3636

3737

3838
enable_full_determinism()
3939

4040

41-
class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
41+
class StableDiffusionSAGPipelineFastTests(
42+
IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
43+
):
4244
pipeline_class = StableDiffusionSAGPipeline
4345
params = TEXT_TO_IMAGE_PARAMS
4446
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

0 commit comments

Comments
 (0)