diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 0cf75b4fad4e..12eef8899bbb 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -102,6 +102,9 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T # Non-power of 2 images must be float32 if (W & (W - 1)) != 0 or (H & (H - 1)) != 0: x = x.to(dtype=torch.float32) + # fftn does not support bfloat16 + elif x.dtype == torch.bfloat16: + x = x.to(dtype=torch.float32) # FFT x_freq = fftn(x, dim=(-2, -1))