 import torch.nn.functional as f
 from torch import Tensor, nn
 from torch.fft import irfftn, rfftn
+from math import ceil, floor
 
 
 def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
@@ -55,7 +56,7 @@ def fft_conv(
     signal: Tensor,
     kernel: Tensor,
     bias: Tensor = None,
-    padding: Union[int, Iterable[int]] = 0,
+    padding: Union[int, Iterable[int], str] = 0,
     padding_mode: str = "constant",
     stride: Union[int, Iterable[int]] = 1,
     dilation: Union[int, Iterable[int]] = 1,
@@ -69,19 +70,31 @@ def fft_conv(
         signal: (Tensor) Input tensor to be convolved with the kernel.
         kernel: (Tensor) Convolution kernel.
         bias: (Tensor) Bias tensor to add to the output.
-        padding: (Union[int, Iterable[int]) Number of zero samples to pad the
-            input on the last dimension.
+        padding: (Union[int, Iterable[int], str]) If int, number of zero samples to pad the
+            input on the last dimension. If str, "same" is supported, which pads the input to preserve its size.
+        padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
+            Reflection is not available for 3d signals.
         stride: (Union[int, Iterable[int]) Stride size for computing output values.
+        dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
+        groups: (int) Number of groups for the convolution.
 
     Returns:
         (Tensor) Convolved tensor
     """
 
     # Cast padding, stride & dilation to tuples.
     n = signal.ndim - 2
-    padding_ = to_ntuple(padding, n=n)
     stride_ = to_ntuple(stride, n=n)
     dilation_ = to_ntuple(dilation, n=n)
+    if isinstance(padding, str):
+        if padding == "same":
+            if stride != 1 or dilation != 1:
+                raise ValueError("stride and dilation must be 1 for padding='same'.")
+            padding_ = [(k - 1) / 2 for k in kernel.shape[2:]]
+        else:
+            raise ValueError(f"Padding mode {padding} not supported.")
+    else:
+        padding_ = to_ntuple(padding, n=n)
 
     # internal dilation offsets
     offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
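
A quick illustration of what the new "same" branch computes (a minimal sketch, not part of the patch; kernel_size below is a hypothetical stand-in for kernel.shape[2:]):

    # Sketch only: per-dimension padding for padding="same".
    # Odd kernels pad both sides equally; even kernels need halves that
    # differ by one sample, which the floor/ceil rounding below provides.
    kernel_size = (5, 4)                       # illustrative 2D kernel shape
    padding_ = [(k - 1) / 2 for k in kernel_size]
    # -> [2.0, 1.5]
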
@@ -93,8 +106,8 @@ def fft_conv(
     # pad the kernel internally according to the dilation parameters
     kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]
 
-    # Pad the input signal & kernel tensors
-    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
+    # Pad the input signal & kernel tensors (round to support even-sized kernels)
+    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
     signal = f.pad(signal, signal_padding, mode=padding_mode)
 
     # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
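
The rounding on the new signal_padding line can be checked in isolation; a small sketch, assuming padding_ came from the "same" branch above:

    from math import ceil, floor

    padding_ = [2.0, 1.5]                      # e.g. from a kernel of size (5, 4)
    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
    # -> [1, 2, 2, 2]: f.pad takes (left, right) pairs starting from the last
    #    dimension, so the even-sized dimension gets 1 and 2 samples of padding.
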
@@ -155,9 +168,14 @@ def __init__(
             out_channels: (int) Number of channels in output tensors
             kernel_size: (Union[int, Iterable[int]) Square radius of the kernel
             padding: (Union[int, Iterable[int]) Number of zero samples to pad the
-                input on the last dimension.
+                input on the last dimension. If str, "same" is supported, which pads the input to preserve its size.
+            padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
+                Reflection is not available for 3d signals.
             stride: (Union[int, Iterable[int]) Stride size for computing output values.
+            dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
+            groups: (int) Number of groups for the convolution.
             bias: (bool) If True, includes bias, which is added after convolution
+            ndim: (int) Number of spatial dimensions of the input (1, 2, or 3).
         """
         super().__init__()
         self.in_channels = in_channels
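
Taken together, the change lets callers request size-preserving convolutions directly. A usage sketch, assuming the fft_conv function from this diff and the usual (batch, channels, length) layout:

    import torch

    signal = torch.randn(2, 3, 100)            # batch of 1d signals
    kernel = torch.randn(4, 3, 4)              # even kernel size of 4
    out = fft_conv(signal, kernel, padding="same")
    assert out.shape == (2, 4, 100)            # spatial size is preserved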