@@ -373,18 +373,29 @@ def encode_prompt(
373373 return prompt_embeds , negative_prompt_embeds
374374
375375 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
376- def encode_image (self , image , device , num_images_per_prompt ):
376+ def encode_image (self , image , device , num_images_per_prompt , output_hidden_states = None ):
377377 dtype = next (self .image_encoder .parameters ()).dtype
378378
379379 if not isinstance (image , torch .Tensor ):
380380 image = self .feature_extractor (image , return_tensors = "pt" ).pixel_values
381381
382382 image = image .to (device = device , dtype = dtype )
383- image_embeds = self .image_encoder (image ).image_embeds
384- image_embeds = image_embeds .repeat_interleave (num_images_per_prompt , dim = 0 )
383+ if output_hidden_states :
384+ image_enc_hidden_states = self .image_encoder (image , output_hidden_states = True ).hidden_states [- 2 ]
385+ image_enc_hidden_states = image_enc_hidden_states .repeat_interleave (num_images_per_prompt , dim = 0 )
386+ uncond_image_enc_hidden_states = self .image_encoder (
387+ torch .zeros_like (image ), output_hidden_states = True
388+ ).hidden_states [- 2 ]
389+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states .repeat_interleave (
390+ num_images_per_prompt , dim = 0
391+ )
392+ return image_enc_hidden_states , uncond_image_enc_hidden_states
393+ else :
394+ image_embeds = self .image_encoder (image ).image_embeds
395+ image_embeds = image_embeds .repeat_interleave (num_images_per_prompt , dim = 0 )
396+ uncond_image_embeds = torch .zeros_like (image_embeds )
385397
386- uncond_image_embeds = torch .zeros_like (image_embeds )
387- return image_embeds , uncond_image_embeds
398+ return image_embeds , uncond_image_embeds
388399
389400 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
390401 def prepare_ip_adapter_image_embeds (
0 commit comments