@@ -56,6 +56,7 @@ def fft_conv(
5656 kernel : Tensor ,
5757 bias : Tensor = None ,
5858 padding : Union [int , Iterable [int ]] = 0 ,
59+ padding_mode : str = 'constant' ,
5960 stride : Union [int , Iterable [int ]] = 1 ,
6061 dilation : Union [int , Iterable [int ]] = 1 ,
6162 groups : int = 1 ,
@@ -83,7 +84,7 @@ def fft_conv(
8384 dilation_ = to_ntuple (dilation , n = n )
8485
8586 # internal dilation offsets
86- offset = torch .zeros (1 , 1 , * dilation_ )
87+ offset = torch .zeros (1 , 1 , * dilation_ , device = signal . device , dtype = signal . dtype )
8788 offset [(slice (None ), slice (None ), * ((0 ,) * n ))] = 1.0
8889
8990 # correct the kernel by cutting off unwanted dilation trailing zeros
@@ -94,7 +95,7 @@ def fft_conv(
9495
9596 # Pad the input signal & kernel tensors
9697 signal_padding = [p for p in padding_ [::- 1 ] for _ in range (2 )]
97- signal = f .pad (signal , signal_padding )
98+ signal = f .pad (signal , signal_padding , mode = padding_mode )
9899
99100 # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
100101 # have *even* length. Just pad with one more zero if the final dimension is odd.
@@ -143,6 +144,7 @@ def __init__(
143144 out_channels : int ,
144145 kernel_size : Union [int , Iterable [int ]],
145146 padding : Union [int , Iterable [int ]] = 0 ,
147+ padding_mode : str = 'constant' ,
146148 stride : Union [int , Iterable [int ]] = 1 ,
147149 dilation : Union [int , Iterable [int ]] = 1 ,
148150 groups : int = 1 ,
@@ -164,6 +166,7 @@ def __init__(
164166 self .out_channels = out_channels
165167 self .kernel_size = kernel_size
166168 self .padding = padding
169+ self .padding_mode = padding_mode
167170 self .stride = stride
168171 self .dilation = dilation
169172 self .groups = groups
@@ -192,6 +195,7 @@ def forward(self, signal):
192195 self .weight ,
193196 bias = self .bias ,
194197 padding = self .padding ,
198+ padding_mode = self .padding_mode ,
195199 stride = self .stride ,
196200 dilation = self .dilation ,
197201 groups = self .groups ,
0 commit comments