Skip to content

Commit e25f141

Browse files
authored
Update torch_utils.py
1 parent 8141aa6 commit e25f141

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/diffusers/utils/torch_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,14 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
9898
"""
9999
x = x_in
100100
B, C, H, W = x.shape
101+
101102
# Non-power of 2 images must be float32
102103
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
103104
x = x.to(dtype=torch.float32)
104105
# fftn does not support bfloat16
105106
elif x.dtype == torch.bfloat16:
106107
x = x.to(dtype=torch.float32)
107-
108+
108109
# FFT
109110
x_freq = fftn(x, dim=(-2, -1))
110111
x_freq = fftshift(x_freq, dim=(-2, -1))

0 commit comments

Comments
 (0)