Skip to content

Commit e7b54d3

Browse files
committed
Avoid signal duplication after padding
1 parent e4bc454 commit e7b54d3

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

fft_conv_pytorch/fft_conv.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,14 @@ def fft_conv(
9999

100100
# Because PyTorch computes a *one-sided* FFT, we need the final dimension to
101101
# have *even* length. Just pad with one more zero if the final dimension is odd.
102+
signal_size = signal.size() # original signal size without padding to even
102103
if signal.size(-1) % 2 != 0:
103-
signal_ = f.pad(signal, [0, 1])
104-
else:
105-
signal_ = signal
104+
signal = f.pad(signal, [0, 1])
106105

107106
kernel_padding = [
108107
pad
109-
for i in reversed(range(2, signal_.ndim))
110-
for pad in [0, signal_.size(i) - kernel.size(i)]
108+
for i in reversed(range(2, signal.ndim))
109+
for pad in [0, signal.size(i) - kernel.size(i)]
111110
]
112111
padded_kernel = f.pad(kernel, kernel_padding)
113112

@@ -121,8 +120,8 @@ def fft_conv(
121120
output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))
122121

123122
# Remove extra padded values
124-
crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
125-
slice(0, (signal.size(i) - kernel.size(i) + 1), stride_[i - 2])
123+
crop_slices = [slice(None), slice(None)] + [
124+
slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2])
126125
for i in range(2, signal.ndim)
127126
]
128127
output = output[crop_slices].contiguous()

0 commit comments

Comments
 (0)