@@ -19,7 +19,7 @@
 import PIL
 import regex as re
 import torch
-from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
@@ -49,11 +49,11 @@
         >>> import numpy as np
         >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
         >>> from diffusers.utils import export_to_video, load_image
-        >>> from transformers import CLIPVisionModel
+        >>> from transformers import CLIPVisionModelWithProjection
 
         >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
         >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-        >>> image_encoder = CLIPVisionModel.from_pretrained(
+        >>> image_encoder = CLIPVisionModelWithProjection.from_pretrained(
         ...     model_id, subfolder="image_encoder", torch_dtype=torch.float32
         ... )
         >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
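Note on the class swap above: `CLIPVisionModelWithProjection` wraps the same vision tower as `CLIPVisionModel` but also applies CLIP's `visual_projection` head, so its forward output exposes `image_embeds` in addition to the hidden states. A quick way to see the difference in transformers; the checkpoint below is the standard OpenAI CLIP model, used purely for illustration, not the Wan image encoder:

    import torch
    from transformers import CLIPVisionModel, CLIPVisionModelWithProjection

    ckpt = "openai/clip-vit-base-patch32"  # illustrative checkpoint only
    pixels = torch.randn(1, 3, 224, 224)  # stand-in for CLIPImageProcessor output

    plain = CLIPVisionModel.from_pretrained(ckpt)
    proj = CLIPVisionModelWithProjection.from_pretrained(ckpt)

    with torch.no_grad():
        print(plain(pixel_values=pixels).pooler_output.shape)  # hidden size, e.g. [1, 768]
        print(proj(pixel_values=pixels).image_embeds.shape)    # projection dim, e.g. [1, 512]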
@@ -109,14 +109,30 @@ def prompt_clean(text):
 
 
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    latents_mean: torch.Tensor,
+    latents_std: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
+        return (encoder_output.latents - latents_mean) * latents_std
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
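For context on the new signature: `retrieve_latents` now receives the Wan VAE's per-channel `latents_mean` and `latents_std` and folds the normalization into the latent distribution itself, shifting the mean and rescaling (then clamping) the log-variance before sampling, rather than normalizing the sampled tensor afterwards in `prepare_latents`. A minimal sketch of the new call shape, assuming a diffusers build that already includes this change; the `FakeEncodeOutput` wrapper, the shapes, and the import paths are illustrative assumptions:

    import torch
    from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
    from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents

    class FakeEncodeOutput:
        # Hypothetical stand-in for the output object returned by vae.encode(...).
        def __init__(self, parameters):
            self.latent_dist = DiagonalGaussianDistribution(parameters)

    z_dim = 16  # assumed latent channel count; the real value is vae.config.z_dim
    params = torch.randn(1, 2 * z_dim, 2, 4, 4)  # mean and logvar stacked on dim 1
    latents_mean = torch.zeros(1, z_dim, 1, 1, 1)
    latents_std = torch.ones(1, z_dim, 1, 1, 1)

    latents = retrieve_latents(FakeEncodeOutput(params), latents_mean, latents_std, sample_mode="sample")
    print(latents.shape)  # torch.Size([1, 16, 2, 4, 4])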
@@ -155,7 +171,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
+        image_encoder: CLIPVisionModelWithProjection,
         image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
@@ -385,13 +401,6 @@ def prepare_latents(
             )
         video_condition = video_condition.to(device=device, dtype=dtype)
 
-        if isinstance(generator, list):
-            latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
-            latents = latent_condition = torch.cat(latent_condition)
-        else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
-            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
-
         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
             .view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -401,7 +410,14 @@ def prepare_latents(
             latents.device, latents.dtype
         )
 
-        latent_condition = (latent_condition - latents_mean) * latents_std
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
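One easy sanity check for moving the `vae.encode` calls below the `latents_mean` / `latents_std` computation: in `argmax` mode the old and new paths must agree exactly, because `DiagonalGaussianDistribution.mode()` simply returns the distribution mean, so normalizing the mean in place and then taking the mode is the same as taking the mode and then normalizing the result. A throwaway check along these lines, with the import path and shapes as assumptions:

    import torch
    from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

    z_dim = 16  # Wan's latent channel count
    params = torch.randn(1, 2 * z_dim, 2, 4, 4)
    latents_mean = torch.randn(1, z_dim, 1, 1, 1)
    latents_std = torch.rand(1, z_dim, 1, 1, 1)

    # Old path: take the mode, then normalize the returned tensor.
    old = (DiagonalGaussianDistribution(params).mode() - latents_mean) * latents_std

    # New path: normalize the distribution's mean in place, then take the mode.
    dist = DiagonalGaussianDistribution(params)
    dist.mean = (dist.mean - latents_mean) * latents_std
    new = dist.mode()

    assert torch.allclose(old, new)

In `sample` mode the two paths are not bit-identical, since the new code rescales the log-variance rather than multiplying the standard deviation, so `argmax` is the mode where an exact comparison holds.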