@@ -220,10 +220,29 @@ def _get_t5_prompt_embeds(
220220
221221 return prompt_embeds
222222
223- def encode_image (self , image : PipelineImageInput ):
224- image = self .image_processor (images = image , return_tensors = "pt" ).to (self .device )
225- image_embeds = self .image_encoder (** image , output_hidden_states = True )
226- return image_embeds .hidden_states [- 2 ]
223+ def encode_image (self , image , device , num_images_per_prompt , output_hidden_states = None ):
224+ dtype = next (self .image_encoder .parameters ()).dtype
225+
226+ if not isinstance (image , torch .Tensor ):
227+ image = self .image_processor (image , return_tensors = "pt" ).pixel_values
228+
229+ image = image .to (device = device , dtype = dtype )
230+ if output_hidden_states :
231+ image_enc_hidden_states = self .image_encoder (image , output_hidden_states = True ).hidden_states [- 2 ]
232+ image_enc_hidden_states = image_enc_hidden_states .repeat_interleave (num_images_per_prompt , dim = 0 )
233+ uncond_image_enc_hidden_states = self .image_encoder (
234+ torch .zeros_like (image ), output_hidden_states = True
235+ ).hidden_states [- 2 ]
236+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states .repeat_interleave (
237+ num_images_per_prompt , dim = 0
238+ )
239+ return image_enc_hidden_states , uncond_image_enc_hidden_states
240+ else :
241+ image_embeds = self .image_encoder (image ).image_embeds
242+ image_embeds = image_embeds .repeat_interleave (num_images_per_prompt , dim = 0 )
243+ uncond_image_embeds = torch .zeros_like (image_embeds )
244+
245+ return image_embeds , uncond_image_embeds
227246
228247 # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
229248 def encode_prompt (
@@ -587,8 +606,7 @@ def __call__(
587606 if negative_prompt_embeds is not None :
588607 negative_prompt_embeds = negative_prompt_embeds .to (transformer_dtype )
589608
590- image_embeds = self .encode_image (image )
591- image_embeds = image_embeds .repeat (batch_size , 1 , 1 )
609+ image_embeds , _ = self .encode_image (image , device , num_videos_per_prompt , output_hidden_states = True )
592610 image_embeds = image_embeds .to (transformer_dtype )
593611
594612 # 4. Prepare timesteps
0 commit comments