Skip to content

Commit 1978fb9

Browse files
committed
WanI2V encode_image
1 parent de6a88c commit 1978fb9

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,29 @@ def _get_t5_prompt_embeds(
220220

221221
return prompt_embeds
222222

223-
def encode_image(self, image: PipelineImageInput):
224-
image = self.image_processor(images=image, return_tensors="pt").to(self.device)
225-
image_embeds = self.image_encoder(**image, output_hidden_states=True)
226-
return image_embeds.hidden_states[-2]
223+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
224+
dtype = next(self.image_encoder.parameters()).dtype
225+
226+
if not isinstance(image, torch.Tensor):
227+
image = self.image_processor(image, return_tensors="pt").pixel_values
228+
229+
image = image.to(device=device, dtype=dtype)
230+
if output_hidden_states:
231+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
232+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
233+
uncond_image_enc_hidden_states = self.image_encoder(
234+
torch.zeros_like(image), output_hidden_states=True
235+
).hidden_states[-2]
236+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
237+
num_images_per_prompt, dim=0
238+
)
239+
return image_enc_hidden_states, uncond_image_enc_hidden_states
240+
else:
241+
image_embeds = self.image_encoder(image).image_embeds
242+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
243+
uncond_image_embeds = torch.zeros_like(image_embeds)
244+
245+
return image_embeds, uncond_image_embeds
227246

228247
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
229248
def encode_prompt(
@@ -587,8 +606,7 @@ def __call__(
587606
if negative_prompt_embeds is not None:
588607
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
589608

590-
image_embeds = self.encode_image(image)
591-
image_embeds = image_embeds.repeat(batch_size, 1, 1)
609+
image_embeds, _ = self.encode_image(image, device, num_videos_per_prompt, output_hidden_states=True)
592610
image_embeds = image_embeds.to(transformer_dtype)
593611

594612
# 4. Prepare timesteps

0 commit comments

Comments
 (0)