@@ -99,15 +99,14 @@ def fft_conv(
9999
100100 # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
101101 # have *even* length. Just pad with one more zero if the final dimension is odd.
102+ signal_size = signal .size () # original signal size without padding to even
102103 if signal .size (- 1 ) % 2 != 0 :
103- signal_ = f .pad (signal , [0 , 1 ])
104- else :
105- signal_ = signal
104+ signal = f .pad (signal , [0 , 1 ])
106105
107106 kernel_padding = [
108107 pad
109- for i in reversed (range (2 , signal_ .ndim ))
110- for pad in [0 , signal_ .size (i ) - kernel .size (i )]
108+ for i in reversed (range (2 , signal .ndim ))
109+ for pad in [0 , signal .size (i ) - kernel .size (i )]
111110 ]
112111 padded_kernel = f .pad (kernel , kernel_padding )
113112
@@ -121,8 +120,8 @@ def fft_conv(
121120 output = irfftn (output_fr , dim = tuple (range (2 , signal .ndim )))
122121
123122 # Remove extra padded values
124- crop_slices = [slice (0 , output . size ( 0 )) , slice (0 , output . size ( 1 ) )] + [
125- slice (0 , (signal . size ( i ) - kernel .size (i ) + 1 ), stride_ [i - 2 ])
123+ crop_slices = [slice (None ) , slice (None )] + [
124+ slice (0 , (signal_size [ i ] - kernel .size (i ) + 1 ), stride_ [i - 2 ])
126125 for i in range (2 , signal .ndim )
127126 ]
128127 output = output [crop_slices ].contiguous ()
0 commit comments