def encode_image(
    self,
    image: PipelineImageInput,
    image_embeds: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
):
    """Return image embeddings for conditioning, computing them only if needed.

    If ``image_embeds`` is supplied it is returned unchanged, allowing callers
    to pass precomputed embeddings and skip the image encoder entirely.
    Otherwise ``image`` is preprocessed and run through the image encoder, and
    the penultimate hidden state (``hidden_states[-2]``) is used as the
    embedding, as is conventional for CLIP-style conditioning.

    Args:
        image: Input image to encode (ignored when ``image_embeds`` is given).
        image_embeds: Optional precomputed image embeddings.
        device: Device for preprocessing; defaults to the pipeline's
            execution device.
    """
    if image_embeds is not None:
        # Caller provided precomputed embeddings — nothing to do.
        return image_embeds

    target_device = device or self._execution_device
    pixel_inputs = self.image_processor(images=image, return_tensors="pt").to(target_device)
    encoder_out = self.image_encoder(**pixel_inputs, output_hidden_states=True)
    # Penultimate hidden state, not the final pooled output.
    return encoder_out.hidden_states[-2]
233236 # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
234237 def encode_prompt (
@@ -321,9 +324,18 @@ def check_inputs(
321324 width ,
322325 prompt_embeds = None ,
323326 negative_prompt_embeds = None ,
327+ image_embeds = None ,
324328 callback_on_step_end_tensor_inputs = None ,
325329 ):
326- if not isinstance (image , torch .Tensor ) and not isinstance (image , PIL .Image .Image ):
330+ if image is not None and image_embeds is not None :
331+ raise ValueError (
332+ f"Cannot forward both `image`: { image } and `image_embeds`: { image_embeds } . Please make sure to"
333+ " only forward one of the two."
334+ )
335+ if image is None and image_embeds is None :
336+ raise ValueError (
337+ "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
338+ )
338+ if image is not None and not isinstance (image , torch .Tensor ) and not isinstance (image , PIL .Image .Image ):
327339 raise ValueError ("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" { type (image )} " )
328340 if height % 16 != 0 or width % 16 != 0 :
329341 raise ValueError (f"`height` and `width` have to be divisible by 16 but are { height } and { width } ." )
@@ -463,6 +475,7 @@ def __call__(
463475 latents : Optional [torch .Tensor ] = None ,
464476 prompt_embeds : Optional [torch .Tensor ] = None ,
465477 negative_prompt_embeds : Optional [torch .Tensor ] = None ,
478+ image_embeds : Optional [torch .Tensor ] = None ,
466479 output_type : Optional [str ] = "np" ,
467480 return_dict : bool = True ,
468481 attention_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -512,6 +525,12 @@ def __call__(
512525 prompt_embeds (`torch.Tensor`, *optional*):
513526 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
514527 provided, text embeddings are generated from the `prompt` input argument.
528+ negative_prompt_embeds (`torch.Tensor`, *optional*):
529+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
530+ provided, text embeddings are generated from the `negative_prompt` input argument.
531+ image_embeds (`torch.Tensor`, *optional*):
532+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not
533+ provided, image embeddings are generated from the `image` input argument.
515534 output_type (`str`, *optional*, defaults to `"pil"`):
516535 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
517536 return_dict (`bool`, *optional*, defaults to `True`):
@@ -592,7 +611,7 @@ def __call__(
592611 if negative_prompt_embeds is not None :
593612 negative_prompt_embeds = negative_prompt_embeds .to (transformer_dtype )
594613
595- image_embeds = self .encode_image (image , device )
614+ image_embeds = self .encode_image (image , image_embeds , device )
596615 image_embeds = image_embeds .repeat (batch_size , 1 , 1 )
597616 image_embeds = image_embeds .to (transformer_dtype )
598617
0 commit comments