Skip to content

Commit d8ca6a8

Browse files
authored
Update src/diffusers/utils/torch_utils.py
1 parent 7a88918 commit d8ca6a8

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/diffusers/utils/torch_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,12 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
9898
"""
9999
x = x_in
100100
B, C, H, W = x.shape
101-
x = x.to(dtype=torch.float32)
101+
# Non-power of 2 images must be float32
102+
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
103+
x = x.to(dtype=torch.float32)
104+
# fftn does not support bfloat16
105+
elif x.dtype == torch.bfloat16:
106+
x = x.to(dtype=torch.float32)
102107

103108
# FFT
104109
x_freq = fftn(x, dim=(-2, -1))

0 commit comments

Comments
 (0)