@@ -722,23 +722,30 @@ def __call__(
722722 if output_type == "latent" :
723723 video = latents
724724 else :
725- # unscale/denormalize the latents
726- # denormalize with the mean and std if available and not None
727- has_latents_mean = hasattr (self .vae .config , "latents_mean" ) and self .vae .config .latents_mean is not None
728- has_latents_std = hasattr (self .vae .config , "latents_std" ) and self .vae .config .latents_std is not None
729- if has_latents_mean and has_latents_std :
730- latents_mean = (
731- torch .tensor (self .vae .config .latents_mean ).view (1 , 12 , 1 , 1 , 1 ).to (latents .device , latents .dtype )
725+ with torch .autocast ("cuda" , torch .float32 ):
726+ # unscale/denormalize the latents
727+ # denormalize with the mean and std if available and not None
728+ has_latents_mean = (
729+ hasattr (self .vae .config , "latents_mean" ) and self .vae .config .latents_mean is not None
732730 )
733- latents_std = (
734- torch .tensor (self .vae .config .latents_std ).view (1 , 12 , 1 , 1 , 1 ).to (latents .device , latents .dtype )
735- )
736- latents = latents * latents_std / self .vae .config .scaling_factor + latents_mean
737- else :
738- latents = latents / self .vae .config .scaling_factor
731+ has_latents_std = hasattr (self .vae .config , "latents_std" ) and self .vae .config .latents_std is not None
732+ if has_latents_mean and has_latents_std :
733+ latents_mean = (
734+ torch .tensor (self .vae .config .latents_mean )
735+ .view (1 , 12 , 1 , 1 , 1 )
736+ .to (latents .device , latents .dtype )
737+ )
738+ latents_std = (
739+ torch .tensor (self .vae .config .latents_std )
740+ .view (1 , 12 , 1 , 1 , 1 )
741+ .to (latents .device , latents .dtype )
742+ )
743+ latents = latents * latents_std / self .vae .config .scaling_factor + latents_mean
744+ else :
745+ latents = latents / self .vae .config .scaling_factor
739746
740- video = self .vae .decode (latents , return_dict = False )[0 ]
741- video = self .video_processor .postprocess_video (video , output_type = output_type )
747+ video = self .vae .decode (latents , return_dict = False )[0 ]
748+ video = self .video_processor .postprocess_video (video , output_type = output_type )
742749
743750 # Offload all models
744751 self .maybe_free_model_hooks ()
0 commit comments