@@ -57,6 +57,7 @@ def fft_conv(
5757 bias : Tensor = None ,
5858 padding : Union [int , Iterable [int ]] = 0 ,
5959 stride : Union [int , Iterable [int ]] = 1 ,
60+ dilation : Union [int , Iterable [int ]] = 1 ,
6061 groups : int = 1 ,
6162) -> Tensor :
6263 """Performs N-d convolution of Tensors using a fast fourier transform, which
@@ -74,9 +75,23 @@ def fft_conv(
7475 Returns:
7576 (Tensor) Convolved tensor
7677 """
77- # Cast padding & stride to tuples.
78- padding_ = to_ntuple (padding , n = signal .ndim - 2 )
79- stride_ = to_ntuple (stride , n = signal .ndim - 2 )
78+
79+ # Cast padding, stride & dilation to tuples.
80+ n = signal .ndim - 2
81+ padding_ = to_ntuple (padding , n = n )
82+ stride_ = to_ntuple (stride , n = n )
83+ dilation_ = to_ntuple (dilation , n = n )
84+
85+ # internal dilation offsets
86+ offset = torch .zeros (1 , 1 , * dilation_ )
87+ offset [(slice (None ), slice (None ), * ((0 ,) * n ))] = 1.
88+
89+ # correct the kernel by cutting off unwanted dilation trailing zeros
90+ cutoff = tuple (
91+ slice (None , - d + 1 if d != 1 else None ) for d in dilation_ )
92+
93+ # pad the kernel internally according to the dilation parameters
94+ kernel = torch .kron (kernel , offset )[(slice (None ), slice (None )) + cutoff ]
8095
8196 # Pad the input signal & kernel tensors
8297 signal_padding = [p for p in padding_ [::- 1 ] for _ in range (2 )]
@@ -167,21 +182,8 @@ def __init__(
167182 )
168183
169184 kernel_size = to_ntuple (kernel_size , ndim )
170- dilation = to_ntuple (dilation , ndim )
171- total_size = tuple (
172- ((ks - 1 ) * dil + 1 )
173- for ks , dil in zip (kernel_size , dilation )
174- )
175- weight = torch .zeros (out_channels , in_channels // groups , * total_size )
176- fill = torch .randn (out_channels , in_channels // groups , * kernel_size )
177- ids = tuple (
178- torch .arange (0 , tot_sz , dil )
179- for tot_sz , dil in zip (total_size , dilation )
180- )
181-
182- # workaround bc PyTorch doesn't support [:, :, <tensor tuple>] indexing
183- weight [(slice (None ), slice (None ),) + torch .meshgrid (* ids )] = fill
184-
185+ weight = torch .randn (out_channels , in_channels // groups , * kernel_size )
186+
185187 self .weight = nn .Parameter (weight )
186188 self .bias = nn .Parameter (torch .randn (out_channels )) if bias else None
187189
@@ -192,6 +194,7 @@ def forward(self, signal):
192194 bias = self .bias ,
193195 padding = self .padding ,
194196 stride = self .stride ,
197+ dilation = self .dilation ,
195198 groups = self .groups ,
196199 )
197200
0 commit comments