@@ -415,6 +415,24 @@ def _unpack_latents(
415415        latents  =  latents .permute (0 , 4 , 1 , 5 , 2 , 6 , 3 , 7 ).flatten (6 , 7 ).flatten (4 , 5 ).flatten (2 , 3 )
416416        return  latents 
417417
418+     @staticmethod  
419+     def  _normalize_latents (
420+         latents : torch .Tensor , latents_mean : torch .Tensor , latents_std : torch .Tensor , scaling_factor : float  =  1.0 
421+     ) ->  torch .Tensor :
422+         latents_mean  =  latents_mean .view (1 , - 1 , 1 , 1 , 1 ).to (latents .device , latents .dtype )
423+         latents_std  =  latents_std .view (1 , - 1 , 1 , 1 , 1 ).to (latents .device , latents .dtype )
424+         latents  =  (latents  -  latents_mean ) *  scaling_factor  /  latents_std 
425+         return  latents 
426+ 
427+     @staticmethod  
428+     def  _denormalize_latents (
429+         latents : torch .Tensor , latents_mean : torch .Tensor , latents_std : torch .Tensor , scaling_factor : float  =  1.0 
430+     ) ->  torch .Tensor :
431+         latents_mean  =  latents_mean .view (1 , - 1 , 1 , 1 , 1 ).to (latents .device , latents .dtype )
432+         latents_std  =  latents_std .view (1 , - 1 , 1 , 1 , 1 ).to (latents .device , latents .dtype )
433+         latents  =  latents  *  latents_std  /  scaling_factor  +  latents_mean 
434+         return  latents 
435+ 
418436    def  prepare_latents (
419437        self ,
420438        batch_size : int  =  1 ,
@@ -443,7 +461,9 @@ def prepare_latents(
443461            )
444462
445463        latents  =  randn_tensor (shape , generator = generator , device = device , dtype = dtype )
446-         latents  =  self ._pack_latents (latents , self .transformer_spatial_patch_size , self .transformer_temporal_patch_size )
464+         latents  =  self ._pack_latents (
465+             latents , self .transformer_spatial_patch_size , self .transformer_temporal_patch_size 
466+         )
447467        return  latents 
448468
449469    @property  
@@ -709,15 +729,17 @@ def __call__(
709729        if  output_type  ==  "latent" :
710730            video  =  latents 
711731        else :
712-             latents  =  self ._unpack_latents (latents , latent_num_frames , latent_height , latent_width , self .transformer_spatial_patch_size , self .transformer_temporal_patch_size )
713-             # unscale/denormalize the latents 
714-             latents_mean  =  self .vae .latents_mean .view (1 , self .vae .config .latent_channels , 1 , 1 , 1 ).to (
715-                 latents .device , latents .dtype 
732+             latents  =  self ._unpack_latents (
733+                 latents ,
734+                 latent_num_frames ,
735+                 latent_height ,
736+                 latent_width ,
737+                 self .transformer_spatial_patch_size ,
738+                 self .transformer_temporal_patch_size ,
716739            )
717-             latents_std  =  self .vae . latents_std . view ( 1 ,  self . vae . config . latent_channels ,  1 ,  1 ,  1 ). to (
718-                 latents . device ,  latents . dtype 
740+             latents  =  self ._denormalize_latents (
741+                 latents ,  self . vae . latents_mean ,  self . vae . latents_std ,  self . vae . config . scaling_factor 
719742            )
720-             latents  =  latents  *  latents_std  /  self .vae .config .scaling_factor  +  latents_mean 
721743            video  =  self .vae .decode (latents , return_dict = False )[0 ]
722744            video  =  self .video_processor .postprocess_video (video , output_type = output_type )
723745
0 commit comments