@@ -220,29 +220,15 @@ def _get_t5_prompt_embeds(
220220
221221 return prompt_embeds
222222
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode `image` with `self.image_encoder` and build a matching unconditional encoding.

    Args:
        image: A `torch.Tensor`, or any other input, which is first passed through
            `self.image_processor(..., return_tensors="pt")` to obtain `pixel_values`.
        device: Device the image tensor is moved to before encoding.
        num_images_per_prompt: Number of times each entry is repeated along dim 0.
        output_hidden_states: When truthy, return the second-to-last hidden state of the
            encoder; otherwise return the encoder's `image_embeds`.

    Returns:
        A `(conditional, unconditional)` tuple. With hidden states, the unconditional part
        is the encoder output for an all-zeros image; otherwise it is a zero tensor shaped
        like the conditional embeddings.
    """
    encoder = self.image_encoder
    # Run the encoder in whatever dtype its weights are stored in.
    encoder_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.image_processor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(image).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def _penultimate_hidden(pixels):
        # Second-to-last hidden state, repeated once per requested image.
        hidden = encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return hidden.repeat_interleave(num_images_per_prompt, dim=0)

    cond_hidden = _penultimate_hidden(image)
    uncond_hidden = _penultimate_hidden(torch.zeros_like(image))
    return cond_hidden, uncond_hidden
def encode_image(
    self,
    image: PipelineImageInput,
    device: Optional[torch.device] = None,
):
    """Return the second-to-last hidden state of `self.image_encoder` for `image`.

    Args:
        image: Input forwarded to `self.image_processor` as `images=`.
        device: Device the processed inputs are moved to; falls back to
            `self._execution_device` when not provided.

    Returns:
        `hidden_states[-2]` from the encoder output. No repetition per prompt is
        applied here.
    """
    target_device = device or self._execution_device
    encoder_inputs = self.image_processor(images=image, return_tensors="pt")
    encoder_inputs = encoder_inputs.to(target_device)
    # Hidden states must be requested explicitly so that index -2 is available.
    encoder_output = self.image_encoder(**encoder_inputs, output_hidden_states=True)
    return encoder_output.hidden_states[-2]
246232
247233 # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
248234 def encode_prompt (
@@ -606,7 +592,8 @@ def __call__(
606592 if negative_prompt_embeds is not None :
607593 negative_prompt_embeds = negative_prompt_embeds .to (transformer_dtype )
608594
609- image_embeds , _ = self .encode_image (image , device , num_videos_per_prompt , output_hidden_states = True )
595+ image_embeds = self .encode_image (image , device )
596+ image_embeds = image_embeds .repeat (batch_size , 1 , 1 )
610597 image_embeds = image_embeds .to (transformer_dtype )
611598
612599 # 4. Prepare timesteps
0 commit comments