Skip to content

Commit dded243

Browse files
committed
update
1 parent 3ffa711 commit dded243

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 22 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -722,23 +722,30 @@ def __call__(
722722
if output_type == "latent":
723723
video = latents
724724
else:
725-
# unscale/denormalize the latents
726-
# denormalize with the mean and std if available and not None
727-
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
728-
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
729-
if has_latents_mean and has_latents_std:
730-
latents_mean = (
731-
torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
725+
with torch.autocast("cuda", torch.float32):
726+
# unscale/denormalize the latents
727+
# denormalize with the mean and std if available and not None
728+
has_latents_mean = (
729+
hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
732730
)
733-
latents_std = (
734-
torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
735-
)
736-
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
737-
else:
738-
latents = latents / self.vae.config.scaling_factor
731+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
732+
if has_latents_mean and has_latents_std:
733+
latents_mean = (
734+
torch.tensor(self.vae.config.latents_mean)
735+
.view(1, 12, 1, 1, 1)
736+
.to(latents.device, latents.dtype)
737+
)
738+
latents_std = (
739+
torch.tensor(self.vae.config.latents_std)
740+
.view(1, 12, 1, 1, 1)
741+
.to(latents.device, latents.dtype)
742+
)
743+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
744+
else:
745+
latents = latents / self.vae.config.scaling_factor
739746

740-
video = self.vae.decode(latents, return_dict=False)[0]
741-
video = self.video_processor.postprocess_video(video, output_type=output_type)
747+
video = self.vae.decode(latents, return_dict=False)[0]
748+
video = self.video_processor.postprocess_video(video, output_type=output_type)
742749

743750
# Offload all models
744751
self.maybe_free_model_hooks()

0 commit comments

Comments
 (0)