@@ -108,34 +108,16 @@ def prompt_clean(text):
108108    return  text 
109109
110110
111+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 
111112def  retrieve_latents (
112-     encoder_output : torch .Tensor ,
113-     latents_mean : torch .Tensor ,
114-     latents_std : torch .Tensor ,
115-     generator : Optional [torch .Generator ] =  None ,
116-     sample_mode : str  =  "argmax" ,
113+     encoder_output : torch .Tensor , generator : Optional [torch .Generator ] =  None , sample_mode : str  =  "sample" 
117114):
118-     if  hasattr (encoder_output , "latent_dist" ) and  sample_mode  ==  "argmax" :
119-         return  (encoder_output .latent_dist .mean  -  latents_mean ) *  latents_std 
120-     else :
121-            raise  AttributeError ("Could not access latents of provided encoder_output" )
122-         encoder_output .latent_dist .mean  =  (encoder_output .latent_dist .mean  -  latents_mean ) *  latents_std 
123-         encoder_output .latent_dist .logvar  =  torch .clamp (
124-             (encoder_output .latent_dist .logvar  -  latents_mean ) *  latents_std , - 30.0 , 20.0 
125-         )
126-         encoder_output .latent_dist .std  =  torch .exp (0.5  *  encoder_output .latent_dist .logvar )
127-         encoder_output .latent_dist .var  =  torch .exp (encoder_output .latent_dist .logvar )
115+     if  hasattr (encoder_output , "latent_dist" ) and  sample_mode  ==  "sample" :
128116        return  encoder_output .latent_dist .sample (generator )
129117    elif  hasattr (encoder_output , "latent_dist" ) and  sample_mode  ==  "argmax" :
130-         encoder_output .latent_dist .mean  =  (encoder_output .latent_dist .mean  -  latents_mean ) *  latents_std 
131-         encoder_output .latent_dist .logvar  =  torch .clamp (
132-             (encoder_output .latent_dist .logvar  -  latents_mean ) *  latents_std , - 30.0 , 20.0 
133-         )
134-         encoder_output .latent_dist .std  =  torch .exp (0.5  *  encoder_output .latent_dist .logvar )
135-         encoder_output .latent_dist .var  =  torch .exp (encoder_output .latent_dist .logvar )
136118        return  encoder_output .latent_dist .mode ()
137119    elif  hasattr (encoder_output , "latents" ):
138-         return  ( encoder_output .latents   -   latents_mean )  *   latents_std 
120+         return  encoder_output .latents 
139121    else :
140122        raise  AttributeError ("Could not access latents of provided encoder_output" )
141123
@@ -415,13 +397,15 @@ def prepare_latents(
415397
416398        if  isinstance (generator , list ):
417399            latent_condition  =  [
418-                 retrieve_latents (self .vae .encode (video_condition ), latents_mean ,  latents_std ,  g ) for  g  in  generator 
400+                 retrieve_latents (self .vae .encode (video_condition ), sample_mode = "argmax" ) for  _  in  generator 
419401            ]
420402            latent_condition  =  torch .cat (latent_condition )
421403        else :
422-             latent_condition  =  retrieve_latents (self .vae .encode (video_condition ), latents_mean ,  latents_std ,  generator )
404+             latent_condition  =  retrieve_latents (self .vae .encode (video_condition ), sample_mode = "argmax" )
423405            latent_condition  =  latent_condition .repeat (batch_size , 1 , 1 , 1 , 1 )
424406
407+         latent_condition  =  (latent_condition  -  latents_mean ) *  latents_std 
408+ 
425409        mask_lat_size  =  torch .ones (batch_size , 1 , num_frames , latent_height , latent_width )
426410        mask_lat_size [:, :, list (range (1 , num_frames ))] =  0 
427411        first_frame_mask  =  mask_lat_size [:, :, 0 :1 ]
0 commit comments