from math import ceil, floor

import torch.nn.functional as f
from torch import Tensor, nn
from torch.fft import irfftn, rfftn
910
1011def complex_matmul (a : Tensor , b : Tensor , groups : int = 1 ) -> Tensor :
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int], str] = 0,
    padding_mode: str = "constant",
    stride: Union[int, Iterable[int]] = 1,
    dilation: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast Fourier transform.

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output.
        padding: (Union[int, Iterable[int], str]) If int, number of zero samples
            to pad the input on the last dimension. If str, "same" supported to
            pad input for size preservation.
        padding_mode: (str) Padding mode to use from {constant, reflection,
            replication}. reflection not available for 3d.
        stride: (Union[int, Iterable[int]]) Stride size for computing output values.
        dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
        groups: (int) Number of groups for the convolution.

    Returns:
        (Tensor) Convolved tensor

    Raises:
        ValueError: if padding is a string other than "same", or if
            padding="same" is combined with a non-unit stride or dilation.
    """
    # Cast stride & dilation to length-n tuples (n = number of spatial dims).
    n = signal.ndim - 2
    stride_ = to_ntuple(stride, n=n)
    dilation_ = to_ntuple(dilation, n=n)

    if isinstance(padding, str):
        if padding == "same":
            # "same" output size is only well-defined for unit stride/dilation.
            # Check the *expanded* tuples so the equivalent tuple forms
            # (e.g. stride=(1, 1)) are accepted, not just the scalar 1.
            if any(s != 1 for s in stride_) or any(d != 1 for d in dilation_):
                raise ValueError(
                    "stride and dilation must be 1 for padding='same'."
                )
            # May be fractional for even kernels (e.g. 1.5); the floor/ceil
            # split below pads asymmetrically across the two sides.
            padding_ = [(k - 1) / 2 for k in kernel.shape[2:]]
        else:
            raise ValueError(f"Padding mode {padding} not supported.")
    else:
        padding_ = to_ntuple(padding, n=n)

    # internal dilation offsets: a single 1 at the origin of a block of zeros,
    # so kron() below spaces kernel taps `dilation` samples apart.
    offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
    # NOTE(review): the two lines below fall between diff hunks and were
    # reconstructed from their uses — confirm against the original file.
    offset[(slice(None), slice(None)) + (0,) * n] = 1.0
    cutoff = tuple(slice(None, -d + 1 if d != 1 else None) for d in dilation_)

    # pad the kernel internally according to the dilation parameters
    kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

    # Pad the input signal & kernel tensors (round to support even sized convolutions)
    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
    signal = f.pad(signal, signal_padding, mode=padding_mode)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
    signal_size = signal.size()  # original signal size without padding to even
    if signal.size(-1) % 2 != 0:
        signal = f.pad(signal, [0, 1])

    # Zero-pad the kernel up to the (possibly even-padded) signal size so both
    # FFTs share the same spatial dimensions.
    kernel_padding = [
        pad
        for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    signal_fr = rfftn(signal.float(), dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel.float(), dim=tuple(range(2, signal.ndim)))

    # Conjugating the kernel spectrum makes the frequency-domain product
    # equivalent to cross-correlation (matching torch.nn.functional.conv*).
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

    # Remove extra padded values; crop against the pre-even-padding size and
    # apply stride via slice steps.
    crop_slices = [slice(None), slice(None)] + [
        slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2])
        for i in range(2, signal.ndim)
    ]
    output = output[crop_slices].contiguous()

    # NOTE(review): the bias tail lies past the last visible diff hunk and was
    # reconstructed from the `bias` parameter — confirm against the original.
    if bias is not None:
        # Reshape so bias broadcasts over batch and all spatial dimensions.
        output += bias.view(1, -1, *((1,) * n))

    return output
@@ -157,9 +168,14 @@ def __init__(
157168 out_channels: (int) Number of channels in output tensors
158169 kernel_size: (Union[int, Iterable[int]) Square radius of the kernel
159170 padding: (Union[int, Iterable[int]) Number of zero samples to pad the
160- input on the last dimension.
171+ input on the last dimension. If str, "same" supported to pad input for size preservation.
172+ padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
173+ reflection not available for 3d.
161174 stride: (Union[int, Iterable[int]) Stride size for computing output values.
175+ dilation: (Union[int, Iterable[int]) Dilation rate for the kernel.
176+ groups: (int) Number of groups for the convolution.
162177 bias: (bool) If True, includes bias, which is added after convolution
178+ ndim: (int) Number of dimensions of the input tensor.
163179 """
164180 super ().__init__ ()
165181 self .in_channels = in_channels
0 commit comments