Skip to content

Commit 40fcd51

Browse files
committed
Convert signal and kernel to float before rfftn to fix half precision error
1 parent e7b54d3 commit 40fcd51

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

fft_conv_pytorch/fft_conv.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,8 @@ def fft_conv(
111111
padded_kernel = f.pad(kernel, kernel_padding)
112112

113113
# Perform fourier convolution -- FFT, matrix multiply, then IFFT
114-
# signal_ = signal_.reshape(signal_.size(0), groups, -1, *signal_.shape[2:])
115-
signal_fr = rfftn(signal_, dim=tuple(range(2, signal.ndim)))
116-
kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))
114+
signal_fr = rfftn(signal.float(), dim=tuple(range(2, signal.ndim)))
115+
kernel_fr = rfftn(padded_kernel.float(), dim=tuple(range(2, signal.ndim)))
117116

118117
kernel_fr.imag *= -1
119118
output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)

0 commit comments

Comments
 (0)