Skip to content

Commit 9cf2ccc

Browse files
committed
Added striding for FFT convolution operators.
1 parent 35c4bcc commit 9cf2ccc

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

benchmark.py

Lines changed: 12 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -28,19 +28,26 @@ def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchm
2828
return Benchmark(np.mean(times[1:]).item(), np.std(times[1:]).item())
2929

3030

31-
def benchmark_conv(signal: Tensor, kernel: Tensor, bias: Tensor, padding: int = 0):
31+
def benchmark_conv(
32+
signal: Tensor, kernel: Tensor, bias: Tensor, padding: int = 0, stride: int = 1
33+
):
3234
print(f"Signal size: {signal.shape}")
3335
print(f"Kernel size: {kernel.shape}")
3436

3537
torch_conv = {1: f.conv1d, 2: f.conv2d, 3: f.conv3d}[signal.ndim - 2]
36-
direct_time = benchmark(torch_conv, signal, kernel, bias=bias, padding=padding)
37-
fourier_time = benchmark(fft_conv, signal, kernel, bias=bias, padding=padding)
38+
direct_time = benchmark(
39+
torch_conv, signal, kernel, bias=bias, padding=padding, stride=stride
40+
)
41+
fourier_time = benchmark(
42+
fft_conv, signal, kernel, bias=bias, padding=padding, stride=stride
43+
)
3844
print(f"Direct time: {direct_time}")
3945
print(f"Fourier time: {fourier_time}")
4046

41-
y0 = torch_conv(signal, kernel, bias=bias, padding=padding)
42-
y1 = fft_conv(signal, kernel, bias=bias, padding=padding)
47+
y0 = torch_conv(signal, kernel, bias=bias, padding=padding, stride=stride)
48+
y1 = fft_conv(signal, kernel, bias=bias, padding=padding, stride=stride)
4349
abs_error = torch.abs(y0 - y1)
50+
print(f"Output size: {y0.size()}")
4451
print(f"Abs Error Mean: {abs_error.mean():.3E}")
4552
print(f"Abs Error Std Dev: {abs_error.std():.3E}")
4653

fft_conv.py

Lines changed: 21 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -25,7 +25,11 @@ def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
2525

2626

2727
def fft_conv(
28-
signal: Tensor, kernel: Tensor, bias: Tensor = None, padding: int = 0,
28+
signal: Tensor,
29+
kernel: Tensor,
30+
bias: Tensor = None,
31+
padding: int = 0,
32+
stride: int = 1,
2933
) -> Tensor:
3034
"""Performs N-d convolution of Tensors using a fast fourier transform, which
3135
is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
@@ -36,6 +40,7 @@ def fft_conv(
3640
kernel: (Tensor) Convolution kernel.
3741
bias: (Optional, Tensor) Bias tensor to add to the output.
3842
padding: (int) Number of zero samples to pad the input on the last dimension.
43+
stride: (int) Stride size for computing output values.
3944
4045
Returns:
4146
(Tensor) Convolved tensor
@@ -44,7 +49,8 @@ def fft_conv(
4449
signal_padding = (signal.ndim - 2) * [padding, padding]
4550
signal = f.pad(signal, signal_padding)
4651
kernel_padding = [
47-
pad for i in reversed(range(2, signal.ndim))
52+
pad
53+
for i in reversed(range(2, signal.ndim))
4854
for pad in [0, signal.size(i) - kernel.size(i)]
4955
]
5056
padded_kernel = f.pad(kernel, kernel_padding)
@@ -58,8 +64,9 @@ def fft_conv(
5864
output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))
5965

6066
# Remove extra padded values
61-
crop_slices = [slice(0, output.shape[0]), slice(0, output.shape[1])] + [
62-
slice(0, (signal.size(i) - kernel.size(i) + 1)) for i in range(2, signal.ndim)
67+
crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
68+
slice(0, (signal.size(i) - kernel.size(i) + 1), stride)
69+
for i in range(2, signal.ndim)
6370
]
6471
output = output[crop_slices].contiguous()
6572

@@ -80,6 +87,7 @@ def __init__(
8087
out_channels: int,
8188
kernel_size: int,
8289
padding: int = 0,
90+
stride: int = 1,
8391
bias: bool = True,
8492
):
8593
"""
@@ -88,13 +96,15 @@ def __init__(
8896
out_channels: (int) Number of channels in output tensors
8997
kernel_size: (int) Square radius of the convolution kernel
9098
padding: (int) Amount of zero-padding to add to the input tensor
99+
stride: (int) Stride size for computing output values
91100
bias: (bool) If True, includes bias, which is added after convolution
92101
"""
93102
super().__init__()
94103
self.in_channels = in_channels
95104
self.out_channels = out_channels
96105
self.kernel_size = kernel_size
97106
self.padding = padding
107+
self.stride = stride
98108
self.use_bias = bias
99109

100110
self.weight = torch.empty(0)
@@ -106,6 +116,7 @@ def forward(self, signal):
106116
self.weight,
107117
bias=self.bias,
108118
padding=self.padding,
119+
stride=self.stride,
109120
)
110121

111122

@@ -116,14 +127,11 @@ def __init__(
116127
out_channels: int,
117128
kernel_size: int,
118129
padding: int = 0,
130+
stride: int = 1,
119131
bias: bool = True,
120132
):
121133
super().__init__(
122-
in_channels,
123-
out_channels,
124-
kernel_size,
125-
padding=padding,
126-
bias=bias,
134+
in_channels, out_channels, kernel_size, padding=padding, bias=bias,
127135
)
128136
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
129137

@@ -135,13 +143,15 @@ def __init__(
135143
out_channels: int,
136144
kernel_size: int,
137145
padding: int = 0,
146+
stride: int = 1,
138147
bias: bool = True,
139148
):
140149
super().__init__(
141150
in_channels,
142151
out_channels,
143152
kernel_size,
144153
padding=padding,
154+
stride=stride,
145155
bias=bias,
146156
)
147157
self.weight = nn.Parameter(
@@ -150,20 +160,21 @@ def __init__(
150160

151161

152162
class FFTConv3d(_FFTConv):
153-
154163
def __init__(
155164
self,
156165
in_channels: int,
157166
out_channels: int,
158167
kernel_size: int,
159168
padding: int = 0,
169+
stride: int = 1,
160170
bias: bool = True,
161171
):
162172
super().__init__(
163173
in_channels,
164174
out_channels,
165175
kernel_size,
166176
padding=padding,
177+
stride=stride,
167178
bias=bias,
168179
)
169180
self.weight = nn.Parameter(

0 commit comments

Comments (0)