@@ -401,6 +401,40 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
401
401
402
402
return image_embeds , uncond_image_embeds
403
403
404
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build the per-adapter list of IP-Adapter image embeddings.

    Args:
        ip_adapter_image: A single image or a list of images, one per loaded IP Adapter.
        ip_adapter_image_embeds: Optional precomputed embeddings; when given they are
            returned as-is and `ip_adapter_image` is ignored.
        device: Device passed through to `self.encode_image`.
        num_images_per_prompt: How many copies of each embedding to stack.
        do_classifier_free_guidance: When True, negative and positive embeddings are
            concatenated (negative first) into a single tensor per adapter.

    Returns:
        A list with one embedding tensor per IP-Adapter projection layer.

    Raises:
        ValueError: If the number of images does not match the number of IP Adapters.
    """
    # Caller supplied precomputed embeddings: hand them back untouched.
    if ip_adapter_image_embeds is not None:
        return ip_adapter_image_embeds

    if not isinstance(ip_adapter_image, list):
        ip_adapter_image = [ip_adapter_image]

    projection_layers = self.unet.encoder_hid_proj.image_projection_layers
    if len(ip_adapter_image) != len(projection_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(projection_layers)} IP Adapters."
        )

    image_embeds = []
    for adapter_image, projection_layer in zip(ip_adapter_image, projection_layers):
        # `ImageProjection` consumes pooled embeddings; any other projector
        # gets the encoder hidden states instead.
        use_hidden_states = not isinstance(projection_layer, ImageProjection)
        pos_embeds, neg_embeds = self.encode_image(adapter_image, device, 1, use_hidden_states)

        # Duplicate along a new leading dim, one copy per requested image.
        pos_embeds = torch.stack([pos_embeds] * num_images_per_prompt, dim=0)
        neg_embeds = torch.stack([neg_embeds] * num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            # Negative first, then positive — the consumer splits them back
            # apart with `chunk(2)` in the same order.
            pos_embeds = torch.cat([neg_embeds, pos_embeds])
            pos_embeds = pos_embeds.to(device)

        image_embeds.append(pos_embeds)

    return image_embeds
437
+
404
438
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
405
439
def run_safety_checker (self , image , device , dtype ):
406
440
if self .safety_checker is None :
@@ -535,6 +569,7 @@ def __call__(
535
569
prompt_embeds : Optional [torch .FloatTensor ] = None ,
536
570
negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
537
571
ip_adapter_image : Optional [PipelineImageInput ] = None ,
572
+ ip_adapter_image_embeds : Optional [List [torch .FloatTensor ]] = None ,
538
573
output_type : Optional [str ] = "pil" ,
539
574
return_dict : bool = True ,
540
575
callback : Optional [Callable [[int , int , torch .FloatTensor ], None ]] = None ,
@@ -583,6 +618,9 @@ def __call__(
583
618
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
584
619
ip_adapter_image: (`PipelineImageInput`, *optional*):
585
620
Optional image input to work with IP Adapters.
621
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
622
+ Pre-generated image embeddings for IP-Adapter. If not
623
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
586
624
output_type (`str`, *optional*, defaults to `"pil"`):
587
625
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
588
626
return_dict (`bool`, *optional*, defaults to `True`):
@@ -636,13 +674,24 @@ def __call__(
636
674
# `sag_scale = 0` means no self-attention guidance
637
675
do_self_attention_guidance = sag_scale > 0.0
638
676
639
- if ip_adapter_image is not None :
640
- output_hidden_state = False if isinstance (self .unet .encoder_hid_proj , ImageProjection ) else True
641
- image_embeds , negative_image_embeds = self .encode_image (
642
- ip_adapter_image , device , num_images_per_prompt , output_hidden_state
677
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None :
678
+ ip_adapter_image_embeds = self .prepare_ip_adapter_image_embeds (
679
+ ip_adapter_image ,
680
+ ip_adapter_image_embeds ,
681
+ device ,
682
+ batch_size * num_images_per_prompt ,
683
+ do_classifier_free_guidance ,
643
684
)
685
+
644
686
if do_classifier_free_guidance :
645
- image_embeds = torch .cat ([negative_image_embeds , image_embeds ])
687
+ image_embeds = []
688
+ negative_image_embeds = []
689
+ for tmp_image_embeds in ip_adapter_image_embeds :
690
+ single_negative_image_embeds , single_image_embeds = tmp_image_embeds .chunk (2 )
691
+ image_embeds .append (single_image_embeds )
692
+ negative_image_embeds .append (single_negative_image_embeds )
693
+ else :
694
+ image_embeds = ip_adapter_image_embeds
646
695
647
696
# 3. Encode input prompt
648
697
prompt_embeds , negative_prompt_embeds = self .encode_prompt (
@@ -687,8 +736,18 @@ def __call__(
687
736
extra_step_kwargs = self .prepare_extra_step_kwargs (generator , eta )
688
737
689
738
# 6.1 Add image embeds for IP-Adapter
690
- added_cond_kwargs = {"image_embeds" : image_embeds } if ip_adapter_image is not None else None
691
- added_uncond_kwargs = {"image_embeds" : negative_image_embeds } if ip_adapter_image is not None else None
739
+ added_cond_kwargs = (
740
+ {"image_embeds" : image_embeds }
741
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
742
+ else None
743
+ )
744
+
745
+ if do_classifier_free_guidance :
746
+ added_uncond_kwargs = (
747
+ {"image_embeds" : negative_image_embeds }
748
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
749
+ else None
750
+ )
692
751
693
752
# 7. Denoising loop
694
753
store_processor = CrossAttnStoreProcessor ()
0 commit comments