def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build the IP-Adapter image-embedding tensor for a single IP adapter.

    Args:
        ip_adapter_image: image input accepted by ``self.encode_image``; used
            only when ``ip_adapter_image_embeds`` is ``None``.
        ip_adapter_image_embeds: optional list of precomputed embedding
            tensors. With classifier-free guidance, each entry is expected to
            hold the negative and positive embeds concatenated along dim 0
            (assumption from the ``chunk(2)`` call — TODO confirm with caller).
        device: target device for the returned tensor.
        num_images_per_prompt: repetition factor along dim 0 (one copy per
            generated image).
        do_classifier_free_guidance: when True, negative embeds are prepended
            so the result is ``[negative..., positive...]`` along dim 0.

    Returns:
        torch.Tensor: the prepared image embeds, moved to ``device``.
    """
    if ip_adapter_image_embeds is None:
        single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image)
    else:
        # Single-adapter pipeline: only one set of embeds is supported, so
        # use the last (typically the only) entry of the precomputed list.
        # BUG FIX: the previous code assigned the whole *list* to
        # single_image_embeds when guidance was off, which made the
        # torch.cat below fail with a TypeError.
        single_image_embeds = ip_adapter_image_embeds[-1]
        if do_classifier_free_guidance:
            # Precomputed embeds carry [negative, positive] stacked on dim 0.
            single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)

    # One copy of the embeds per image generated for the prompt.
    single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)

    if do_classifier_free_guidance:
        single_negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
        # CFG convention: negative embeds first, then positive.
        single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

    return single_image_embeds.to(device=device)
735715    @torch .no_grad () 
736716    @replace_example_docstring (EXAMPLE_DOC_STRING ) 
0 commit comments