@@ -109,14 +109,30 @@ def prompt_clean(text):
109109
110110
def retrieve_latents(
    encoder_output: torch.Tensor,
    latents_mean: torch.Tensor,
    latents_std: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample",
):
    """Extract latents from a VAE encoder output, normalized in place with the
    given per-channel statistics.

    Args:
        encoder_output: Object exposing either a ``latent_dist`` (a Gaussian
            distribution with ``mean``/``logvar``/``std``/``var`` attributes,
            plus ``sample(generator)`` and ``mode()``) or a plain ``latents``
            tensor.
        latents_mean: Per-channel shift applied as ``(x - latents_mean)``.
        latents_std: Per-channel scale applied as ``* latents_std`` (presumably
            already the reciprocal of the raw std — TODO confirm at call site).
        generator: Optional RNG used when ``sample_mode == "sample"``.
        sample_mode: ``"sample"`` draws from the distribution, ``"argmax"``
            returns its mode.

    Returns:
        The normalized latents tensor.

    Raises:
        AttributeError: If ``encoder_output`` exposes neither ``latent_dist``
            nor ``latents``.
    """

    def _normalize_dist(dist):
        # Shift/scale the distribution's statistics in place so that
        # sample()/mode() directly produce normalized latents.
        dist.mean = (dist.mean - latents_mean) * latents_std
        # NOTE(review): applying the *mean* shift and a linear scale to the
        # log-variance mirrors the original code exactly, but is statistically
        # dubious (a variance would scale by latents_std**2 and should not be
        # mean-shifted) — confirm upstream intent before changing.
        dist.logvar = torch.clamp((dist.logvar - latents_mean) * latents_std, -30.0, 20.0)
        dist.std = torch.exp(0.5 * dist.logvar)
        dist.var = torch.exp(dist.logvar)

    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        _normalize_dist(encoder_output.latent_dist)
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        _normalize_dist(encoder_output.latent_dist)
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        # Deterministic path: the latents are provided directly.
        return (encoder_output.latents - latents_mean) * latents_std
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
122138
@@ -385,13 +401,6 @@ def prepare_latents(
385401 )
386402 video_condition = video_condition .to (device = device , dtype = dtype )
387403
388- if isinstance (generator , list ):
389- latent_condition = [retrieve_latents (self .vae .encode (video_condition ), g ) for g in generator ]
390- latents = latent_condition = torch .cat (latent_condition )
391- else :
392- latent_condition = retrieve_latents (self .vae .encode (video_condition ), generator )
393- latent_condition = latent_condition .repeat (batch_size , 1 , 1 , 1 , 1 )
394-
395404 latents_mean = (
396405 torch .tensor (self .vae .config .latents_mean )
397406 .view (1 , self .vae .config .z_dim , 1 , 1 , 1 )
@@ -401,7 +410,14 @@ def prepare_latents(
401410 latents .device , latents .dtype
402411 )
403412
404- latent_condition = (latent_condition - latents_mean ) * latents_std
413+ if isinstance (generator , list ):
414+ latent_condition = [
415+ retrieve_latents (self .vae .encode (video_condition ), latents_mean , latents_std , g ) for g in generator
416+ ]
417+ latent_condition = torch .cat (latent_condition )
418+ else :
419+ latent_condition = retrieve_latents (self .vae .encode (video_condition ), latents_mean , latents_std , generator )
420+ latent_condition = latent_condition .repeat (batch_size , 1 , 1 , 1 , 1 )
405421
406422 mask_lat_size = torch .ones (batch_size , 1 , num_frames , latent_height , latent_width )
407423 mask_lat_size [:, :, list (range (1 , num_frames ))] = 0
0 commit comments