Skip to content

Commit 1a8d588

Browse files
authored
Add support to pass image embeddings to the pipeline.
It allows computing the image embeddings externally and using them.
1 parent 75d7e5c commit 1a8d588

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,15 @@ def _get_t5_prompt_embeds(
223223
def encode_image(
    self,
    image: "PipelineImageInput",
    image_embeds: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
):
    """Return image embeddings for `image`, or pass through precomputed ones.

    Args:
        image: Input image(s) to embed. Ignored when `image_embeds` is provided.
        image_embeds: Optional precomputed image embeddings. When given, the
            image encoder is skipped entirely and these are returned instead.
        device: Target device; defaults to the pipeline's execution device.

    Returns:
        `torch.Tensor`: the second-to-last hidden state of the image encoder
        (or the provided `image_embeds`), placed on `device`.
    """
    device = device or self._execution_device
    if image_embeds is None:
        image = self.image_processor(images=image, return_tensors="pt").to(device)
        image_embeds = self.image_encoder(**image, output_hidden_states=True)
        image_embeds = image_embeds.hidden_states[-2]
    # Fix: previously, user-provided embeddings were returned without a device
    # move, unlike the computed path; ensure both paths end up on `device`.
    # (`.to` is a no-op when the tensor is already on `device`.)
    return image_embeds.to(device)
232235

233236
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
234237
def encode_prompt(
@@ -321,9 +324,18 @@ def check_inputs(
321324
width,
322325
prompt_embeds=None,
323326
negative_prompt_embeds=None,
327+
image_embeds=None,
324328
callback_on_step_end_tensor_inputs=None,
325329
):
326-
if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
330+
if image is not None and image_embeds is not None:
331+
raise ValueError(
332+
f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
333+
" only forward one of the two."
334+
)
335+
if image is None and image_embeds is None:
336+
raise ValueError(
337+
"Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
338+
)
339+
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
327339
raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
328340
if height % 16 != 0 or width % 16 != 0:
329341
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -463,6 +475,7 @@ def __call__(
463475
latents: Optional[torch.Tensor] = None,
464476
prompt_embeds: Optional[torch.Tensor] = None,
465477
negative_prompt_embeds: Optional[torch.Tensor] = None,
478+
image_embeds: Optional[torch.Tensor] = None,
466479
output_type: Optional[str] = "np",
467480
return_dict: bool = True,
468481
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -512,6 +525,12 @@ def __call__(
512525
prompt_embeds (`torch.Tensor`, *optional*):
513526
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
514527
provided, text embeddings are generated from the `prompt` input argument.
528+
negative_prompt_embeds (`torch.Tensor`, *optional*):
529+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
530+
not provided, negative text embeddings are generated from the `negative_prompt` input argument.
531+
image_embeds (`torch.Tensor`, *optional*):
532+
Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not
533+
provided, image embeddings are generated from the `image` input argument.
515534
output_type (`str`, *optional*, defaults to `"pil"`):
516535
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
517536
return_dict (`bool`, *optional*, defaults to `True`):
@@ -592,7 +611,7 @@ def __call__(
592611
if negative_prompt_embeds is not None:
593612
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
594613

595-
image_embeds = self.encode_image(image, device)
614+
image_embeds = self.encode_image(image, image_embeds, device)
596615
image_embeds = image_embeds.repeat(batch_size, 1, 1)
597616
image_embeds = image_embeds.to(transformer_dtype)
598617

0 commit comments

Comments
 (0)