Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/diffusers/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,8 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
"""
x = x_in
B, C, H, W = x.shape

# Non-power of 2 images must be float32
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
x = x.to(dtype=torch.float32)

x = x.to(dtype=torch.float32)

# FFT
x_freq = fftn(x, dim=(-2, -1))
x_freq = fftshift(x_freq, dim=(-2, -1))
Expand Down