@@ -233,10 +233,11 @@ def encode_image(self, components, image, device, num_images_per_prompt, output_
233233
234234    # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds 
235235    def  prepare_ip_adapter_image_embeds (
236-         self , components , ip_adapter_image , ip_adapter_image_embeds , device , num_images_per_prompt 
236+         self , components , ip_adapter_image , ip_adapter_image_embeds , device , num_images_per_prompt ,  do_classifier_free_guidance 
237237    ):
238238        image_embeds  =  []
239-         negative_image_embeds  =  []
239+         if  do_classifier_free_guidance :
240+             negative_image_embeds  =  []
240241        if  ip_adapter_image_embeds  is  None :
241242            if  not  isinstance (ip_adapter_image , list ):
242243                ip_adapter_image  =  [ip_adapter_image ]
@@ -255,18 +256,21 @@ def prepare_ip_adapter_image_embeds(
255256                )
256257
257258                image_embeds .append (single_image_embeds [None , :])
258-                 negative_image_embeds .append (single_negative_image_embeds [None , :])
259+                 if  do_classifier_free_guidance :
260+                     negative_image_embeds .append (single_negative_image_embeds [None , :])
259261        else :
260262            for  single_image_embeds  in  ip_adapter_image_embeds :
261-                 single_negative_image_embeds , single_image_embeds  =  single_image_embeds .chunk (2 )
262-                 negative_image_embeds .append (single_negative_image_embeds )
263+                 if  do_classifier_free_guidance :
264+                     single_negative_image_embeds , single_image_embeds  =  single_image_embeds .chunk (2 )
265+                     negative_image_embeds .append (single_negative_image_embeds )
263266                image_embeds .append (single_image_embeds )
264267
265268        ip_adapter_image_embeds  =  []
266269        for  i , single_image_embeds  in  enumerate (image_embeds ):
267270            single_image_embeds  =  torch .cat ([single_image_embeds ] *  num_images_per_prompt , dim = 0 )
268-             single_negative_image_embeds  =  torch .cat ([negative_image_embeds [i ]] *  num_images_per_prompt , dim = 0 )
269-             single_image_embeds  =  torch .cat ([single_negative_image_embeds , single_image_embeds ], dim = 0 )
271+             if  do_classifier_free_guidance :
272+                 single_negative_image_embeds  =  torch .cat ([negative_image_embeds [i ]] *  num_images_per_prompt , dim = 0 )
273+                 single_image_embeds  =  torch .cat ([single_negative_image_embeds , single_image_embeds ], dim = 0 )
270274
271275            single_image_embeds  =  single_image_embeds .to (device = device )
272276            ip_adapter_image_embeds .append (single_image_embeds )
0 commit comments