@@ -25,7 +25,11 @@ def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
2525
2626
2727def fft_conv (
28- signal : Tensor , kernel : Tensor , bias : Tensor = None , padding : int = 0 ,
28+ signal : Tensor ,
29+ kernel : Tensor ,
30+ bias : Tensor = None ,
31+ padding : int = 0 ,
32+ stride : int = 1 ,
2933) -> Tensor :
3034 """Performs N-d convolution of Tensors using a fast fourier transform, which
3135 is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
@@ -36,6 +40,7 @@ def fft_conv(
3640 kernel: (Tensor) Convolution kernel.
3741 bias: (Optional, Tensor) Bias tensor to add to the output.
3842 padding: (int) Number of zero samples to pad the input on the last dimension.
43+ stride: (int) Stride size for computing output values.
3944
4045 Returns:
4146 (Tensor) Convolved tensor
@@ -44,7 +49,8 @@ def fft_conv(
4449 signal_padding = (signal .ndim - 2 ) * [padding , padding ]
4550 signal = f .pad (signal , signal_padding )
4651 kernel_padding = [
47- pad for i in reversed (range (2 , signal .ndim ))
52+ pad
53+ for i in reversed (range (2 , signal .ndim ))
4854 for pad in [0 , signal .size (i ) - kernel .size (i )]
4955 ]
5056 padded_kernel = f .pad (kernel , kernel_padding )
@@ -58,8 +64,9 @@ def fft_conv(
5864 output = irfftn (output_fr , dim = tuple (range (2 , signal .ndim )))
5965
6066 # Remove extra padded values
61- crop_slices = [slice (0 , output .shape [0 ]), slice (0 , output .shape [1 ])] + [
62- slice (0 , (signal .size (i ) - kernel .size (i ) + 1 )) for i in range (2 , signal .ndim )
67+ crop_slices = [slice (0 , output .size (0 )), slice (0 , output .size (1 ))] + [
68+ slice (0 , (signal .size (i ) - kernel .size (i ) + 1 ), stride )
69+ for i in range (2 , signal .ndim )
6370 ]
6471 output = output [crop_slices ].contiguous ()
6572
@@ -80,6 +87,7 @@ def __init__(
8087 out_channels : int ,
8188 kernel_size : int ,
8289 padding : int = 0 ,
90+ stride : int = 1 ,
8391 bias : bool = True ,
8492 ):
8593 """
@@ -88,13 +96,15 @@ def __init__(
8896 out_channels: (int) Number of channels in output tensors
8997 kernel_size: (int) Square radius of the convolution kernel
9098 padding: (int) Amount of zero-padding to add to the input tensor
99+ stride: (int) Stride size for computing output values
91100 bias: (bool) If True, includes bias, which is added after convolution
92101 """
93102 super ().__init__ ()
94103 self .in_channels = in_channels
95104 self .out_channels = out_channels
96105 self .kernel_size = kernel_size
97106 self .padding = padding
107+ self .stride = stride
98108 self .use_bias = bias
99109
100110 self .weight = torch .empty (0 )
@@ -106,6 +116,7 @@ def forward(self, signal):
106116 self .weight ,
107117 bias = self .bias ,
108118 padding = self .padding ,
119+ stride = self .stride ,
109120 )
110121
111122
@@ -116,14 +127,11 @@ def __init__(
116127 out_channels : int ,
117128 kernel_size : int ,
118129 padding : int = 0 ,
130+ stride : int = 1 ,
119131 bias : bool = True ,
120132 ):
121133 super ().__init__ (
122- in_channels ,
123- out_channels ,
124- kernel_size ,
125- padding = padding ,
126- bias = bias ,
134+ in_channels , out_channels , kernel_size , padding = padding , bias = bias ,
127135 )
128136 self .weight = nn .Parameter (torch .randn (out_channels , in_channels , kernel_size ))
129137
@@ -135,13 +143,15 @@ def __init__(
135143 out_channels : int ,
136144 kernel_size : int ,
137145 padding : int = 0 ,
146+ stride : int = 1 ,
138147 bias : bool = True ,
139148 ):
140149 super ().__init__ (
141150 in_channels ,
142151 out_channels ,
143152 kernel_size ,
144153 padding = padding ,
154+ stride = stride ,
145155 bias = bias ,
146156 )
147157 self .weight = nn .Parameter (
@@ -150,20 +160,21 @@ def __init__(
150160
151161
152162class FFTConv3d (_FFTConv ):
153-
154163 def __init__ (
155164 self ,
156165 in_channels : int ,
157166 out_channels : int ,
158167 kernel_size : int ,
159168 padding : int = 0 ,
169+ stride : int = 1 ,
160170 bias : bool = True ,
161171 ):
162172 super ().__init__ (
163173 in_channels ,
164174 out_channels ,
165175 kernel_size ,
166176 padding = padding ,
177+ stride = stride ,
167178 bias = bias ,
168179 )
169180 self .weight = nn .Parameter (
0 commit comments