diff --git a/fft_conv_pytorch/fft_conv.py b/fft_conv_pytorch/fft_conv.py index 81a181e..3f59e6f 100644 --- a/fft_conv_pytorch/fft_conv.py +++ b/fft_conv_pytorch/fft_conv.py @@ -132,10 +132,12 @@ def fft_conv( output = irfftn(output_fr, dim=tuple(range(2, signal.ndim))) # Remove extra padded values - crop_slices = [slice(None), slice(None)] + [ - slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2]) - for i in range(2, signal.ndim) - ] + crop_slices = tuple( + [slice(None), slice(None)] + [ + slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2]) + for i in range(2, signal.ndim) + ] + ) output = output[crop_slices].contiguous() # Optionally, add a bias term before returning.