1919import PIL
2020import regex as re
2121import torch
22- from transformers import AutoTokenizer , CLIPImageProcessor , CLIPVisionModel , UMT5EncoderModel
22+ from transformers import AutoTokenizer , CLIPImageProcessor , CLIPVisionModelWithProjection , UMT5EncoderModel
2323
2424from ...callbacks import MultiPipelineCallbacks , PipelineCallback
2525from ...image_processor import PipelineImageInput
@@ -137,7 +137,7 @@ def __init__(
137137 self ,
138138 tokenizer : AutoTokenizer ,
139139 text_encoder : UMT5EncoderModel ,
140- image_encoder : CLIPVisionModel ,
140+ image_encoder : CLIPVisionModelWithProjection ,
141141 image_processor : CLIPImageProcessor ,
142142 transformer : WanTransformer3DModel ,
143143 vae : AutoencoderKLWan ,
@@ -345,9 +345,6 @@ def prepare_latents(
345345 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
346346 latents : Optional [torch .Tensor ] = None ,
347347 ) -> Tuple [torch .Tensor , torch .Tensor ]:
348- if latents is not None :
349- return latents .to (device = device , dtype = dtype )
350-
351348 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
352349 latent_height = height // self .vae_scale_factor_spatial
353350 latent_width = width // self .vae_scale_factor_spatial
@@ -359,11 +356,14 @@ def prepare_latents(
359356 f" size of { batch_size } . Make sure the batch size matches the length of the generators."
360357 )
361358
362- latents = randn_tensor (shape , generator = generator , device = device , dtype = dtype )
359+ if latents is None :
360+ latents = randn_tensor (shape , generator = generator , device = device , dtype = dtype )
361+ else :
362+ latents = latents .to (device = device , dtype = dtype )
363363
364364 image = image .unsqueeze (2 )
365365 video_condition = torch .cat (
366- [image , torch . zeros (image .shape [0 ], image .shape [1 ], num_frames - 1 , height , width )], dim = 2
366+ [image , image . new_zeros (image .shape [0 ], image .shape [1 ], num_frames - 1 , height , width )], dim = 2
367367 )
368368 video_condition = video_condition .to (device = device , dtype = dtype )
369369
@@ -564,7 +564,7 @@ def __call__(
564564 timesteps = self .scheduler .timesteps
565565
566566 # 5. Prepare latent variables
567- num_channels_latents = self .transformer .config .in_channels
567+ num_channels_latents = self .vae .config .z_dim
568568 image = self .video_processor .preprocess (image , height = height , width = width ).to (device , dtype = torch .float32 )
569569 latents , condition = self .prepare_latents (
570570 image ,
0 commit comments