@@ -680,7 +680,16 @@ def interrupt(self):
         return self._interrupt
 
     # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
-    def encode_image(self, image):
+    def encode_image(self, image: PipelineImageInput) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values
 
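For orientation, `encode_image` follows the usual diffusers pattern: non-tensor inputs go through the pipeline's `feature_extractor` to produce pixel values, which are then passed to the pipeline's image encoder. A minimal standalone sketch of that flow, assuming a CLIP vision backbone (the checkpoint and the exact encoder the pipeline actually loads are assumptions for illustration):

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Stand-ins for the pipeline's feature_extractor / image_encoder (assumed components).
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

def encode_image_sketch(image) -> torch.Tensor:
    # PIL images / arrays are preprocessed into pixel values; tensors are assumed ready.
    if not isinstance(image, torch.Tensor):
        image = feature_extractor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        return image_encoder(image).image_embeds  # (batch, projection_dim)

embeds = encode_image_sketch(Image.new("RGB", (224, 224)))
print(embeds.shape)  # torch.Size([1, 768]) for CLIP ViT-L/14
```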
@@ -690,17 +699,42 @@ def encode_image(self, image):
 
     # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
-    ):
-        if ip_adapter_image_embeds is None:
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
             single_image_embeds = self.encode_image(ip_adapter_image)
             if do_classifier_free_guidance:
                 single_negative_image_embeds = torch.zeros_like(single_image_embeds)
         else:
-            if do_classifier_free_guidance:
-                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-            else:
-                single_image_embeds = ip_adapter_image_embeds
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
 
         image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
 
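The branching introduced above can be exercised in isolation with plain tensors. The sketch below mirrors it; note that re-attaching the negative embeddings for classifier-free guidance happens in the part of the method not shown in this hunk, so that final step is an assumption here:

```python
import torch

def prepare_embeds_sketch(image_embeds=None, precomputed=None, num_images_per_prompt=1, do_cfg=True):
    # Precomputed embeds are split into (negative, positive) halves under CFG;
    # otherwise the positive half comes from the encoded image and zeros act as the negative.
    if precomputed is not None:
        if do_cfg:
            negative, positive = precomputed.chunk(2)
        else:
            positive = precomputed
    elif image_embeds is not None:
        positive = image_embeds
        if do_cfg:
            negative = torch.zeros_like(positive)
    else:
        raise ValueError("Neither `image_embeds` nor `precomputed` were provided.")

    positive = torch.cat([positive] * num_images_per_prompt, dim=0)
    if do_cfg:
        negative = torch.cat([negative] * num_images_per_prompt, dim=0)
        # Assumed tail of the method: stack negative before positive so CFG can chunk them later.
        return torch.cat([negative, positive], dim=0)
    return positive

out = prepare_embeds_sketch(image_embeds=torch.randn(1, 768), num_images_per_prompt=2)
print(out.shape)  # torch.Size([4, 768]) -> 2 negative + 2 positive rows
```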
@@ -733,7 +767,7 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -810,11 +844,10 @@ def __call__(
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
             ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
-            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
-                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
-                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
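Per that updated contract, a caller who precomputes `ip_adapter_image_embeds` has to place the negative embedding first when guidance is enabled, since `prepare_ip_adapter_image_embeds` splits the tensor with `.chunk(2)`. A hedged sketch of building such a tensor (the embedding dimension and the commented-out pipeline call are illustrative assumptions):

```python
import torch

batch_size, num_images, emb_dim = 1, 1, 768  # emb_dim depends on the image encoder (assumed here)

positive = torch.randn(batch_size, num_images, emb_dim)  # e.g. the output of encode_image()
negative = torch.zeros_like(positive)                    # zeros as the "unconditional" image embedding

# With classifier-free guidance the negative half must come first.
ip_adapter_image_embeds = torch.cat([negative, positive], dim=0)

# image = pipe(prompt="...", ip_adapter_image_embeds=ip_adapter_image_embeds).images[0]  # hypothetical call
```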
@@ -950,8 +983,6 @@ def __call__(
         )
 
         # 6. Prepare image embeddings
-        # Either image is passed and ip_adapter is active
-        # Or image_embeds are passed directly
         if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
             ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,