From 2b8f0c68cf3838ce3b623d9bc290e6f02d9600ce Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 22 Nov 2024 16:10:11 +0530
Subject: [PATCH] compute fourier features in FP32.

---
 src/diffusers/models/autoencoders/autoencoder_kl_mochi.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
index 0eabf3a26d7c..920b0b62fef6 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -437,7 +437,8 @@ def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         r"""Forward method of the `FourierFeatures` class."""
-
+        original_dtype = inputs.dtype
+        inputs = inputs.to(torch.float32)
         num_channels = inputs.shape[1]
         num_freqs = (self.stop - self.start) // self.step
 
@@ -450,7 +451,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         # Scale channels by frequency.
         h = w * h
 
-        return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1)
+        return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
 
 
 class MochiEncoder3D(nn.Module):