# Per-channel latent statistics from the VAE config, reshaped to
# (1, z_dim, 1, 1, 1) and moved to the device/dtype of `latents`.
stats_shape = (1, self.vae.config.z_dim, 1, 1, 1)

latents_mean = torch.tensor(self.vae.config.latents_mean).view(*stats_shape)
latents_mean = latents_mean.to(device=latents.device, dtype=latents.dtype)

# NOTE: despite its name, `latents_std` holds the reciprocal (1 / std) so the
# normalization below is a multiply. The reciprocal is taken after the cast,
# i.e. in `latents.dtype`.
latents_std = torch.tensor(self.vae.config.latents_std).view(*stats_shape)
latents_std = 1.0 / latents_std.to(device=latents.device, dtype=latents.dtype)

# Standardize the conditioning latents: (x - mean) * (1 / std).
latent_condition = (latent_condition - latents_mean) * latents_std