Skip to content

Commit 84f6d84

Browse files
committed
1 parent 1978fb9 commit 84f6d84

File tree

1 file changed

+11
-24
lines changed

1 file changed

+11
-24
lines changed

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 11 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -220,29 +220,15 @@ def _get_t5_prompt_embeds(
220 220

221 221
return prompt_embeds
222 222

223-
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
224-
dtype = next(self.image_encoder.parameters()).dtype
225-
226-
if not isinstance(image, torch.Tensor):
227-
image = self.image_processor(image, return_tensors="pt").pixel_values
228-
229-
image = image.to(device=device, dtype=dtype)
230-
if output_hidden_states:
231-
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
232-
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
233-
uncond_image_enc_hidden_states = self.image_encoder(
234-
torch.zeros_like(image), output_hidden_states=True
235-
).hidden_states[-2]
236-
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
237-
num_images_per_prompt, dim=0
238-
)
239-
return image_enc_hidden_states, uncond_image_enc_hidden_states
240-
else:
241-
image_embeds = self.image_encoder(image).image_embeds
242-
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
243-
uncond_image_embeds = torch.zeros_like(image_embeds)
244-
245-
return image_embeds, uncond_image_embeds
223+
def encode_image(
224+
self,
225+
image: PipelineImageInput,
226+
device: Optional[torch.device] = None,
227+
):
228+
device = device or self._execution_device
229+
image = self.image_processor(images=image, return_tensors="pt").to(device)
230+
image_embeds = self.image_encoder(**image, output_hidden_states=True)
231+
return image_embeds.hidden_states[-2]
246 232

247 233
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
248 234
def encode_prompt(
@@ -606,7 +592,8 @@ def __call__(
606 592
if negative_prompt_embeds is not None:
607 593
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
608 594

609-
image_embeds, _ = self.encode_image(image, device, num_videos_per_prompt, output_hidden_states=True)
595+
image_embeds = self.encode_image(image, device)
596+
image_embeds = image_embeds.repeat(batch_size, 1, 1)
610 597
image_embeds = image_embeds.to(transformer_dtype)
611 598

612 599
# 4. Prepare timesteps

0 commit comments

Comments (0)