Commit 4abff90

Allow 'padding' and 'stride' to be given as tuples rather than just integers, which effectively enables non-square kernels and strides.
1 parent 9cf2ccc commit 4abff90
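
Below is a minimal usage sketch of what this change enables (editor's illustration, not part of the commit). It assumes fft_conv.py from this repository is importable as `fft_conv`; the names and signatures are taken from the diff below.

    import torch
    from fft_conv import FFTConv2d, fft_conv  # assumed import path

    signal = torch.randn(1, 3, 64, 64)  # (batch, channels, height, width)

    # Module form: non-square kernel, asymmetric padding, different stride per axis.
    conv = FFTConv2d(3, 8, kernel_size=(3, 5), padding=(1, 2), stride=(2, 1))
    out = conv(signal)

    # Functional form: plain ints still work and are broadcast to every spatial
    # dimension by the new to_ntuple helper.
    kernel = torch.randn(8, 3, 3, 3)
    out = fft_conv(signal, kernel, padding=1, stride=2)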

File tree

1 file changed: +47 −76 lines


fft_conv.py

Lines changed: 47 additions & 76 deletions
@@ -1,4 +1,5 @@
 from functools import partial
+from typing import Tuple, Union, Iterable
 
 import torch
 from torch import nn, Tensor
@@ -24,12 +25,33 @@ def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
     return c
 
 
+def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
+    """Casts to a tuple with length 'n'. Useful for automatically computing the
+    padding and stride for convolutions, where users may only provide an integer.
+
+    Args:
+        val: (Union[int, Iterable[int]]) Value to cast into a tuple.
+        n: (int) Desired length of the tuple
+
+    Returns:
+        (Tuple[int, ...]) Tuple of length 'n'
+    """
+    if isinstance(val, Iterable):
+        out = tuple(val)
+        if len(out) == n:
+            return out
+        else:
+            raise ValueError(f"Cannot cast tuple of length {len(out)} to length {n}.")
+    else:
+        return n * (val,)
+
+
 def fft_conv(
     signal: Tensor,
     kernel: Tensor,
     bias: Tensor = None,
-    padding: int = 0,
-    stride: int = 1,
+    padding: Union[int, Iterable[int]] = 0,
+    stride: Union[int, Iterable[int]] = 1,
 ) -> Tensor:
     """Performs N-d convolution of Tensors using a fast fourier transform, which
     is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
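
For reference, a quick illustration of how the `to_ntuple` helper added above behaves (editor's sketch, reproducing the logic shown in the hunk):

    to_ntuple(3, n=2)       # -> (3, 3): an int is broadcast to every dimension
    to_ntuple((3, 5), n=2)  # -> (3, 5): an iterable of the right length passes through
    to_ntuple((3, 5), n=3)  # raises ValueError: Cannot cast tuple of length 2 to length 3.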
@@ -38,15 +60,20 @@ def fft_conv(
     Args:
         signal: (Tensor) Input tensor to be convolved with the kernel.
         kernel: (Tensor) Convolution kernel.
-        bias: (Optional, Tensor) Bias tensor to add to the output.
-        padding: (int) Number of zero samples to pad the input on the last dimension.
-        stride: (int) Stride size for computing output values.
+        bias: (Tensor) Bias tensor to add to the output.
+        padding: (Union[int, Iterable[int]]) Number of zero samples to pad the
+            input on the last dimension.
+        stride: (Union[int, Iterable[int]]) Stride size for computing output values.
 
     Returns:
         (Tensor) Convolved tensor
     """
+    # Cast padding & stride to tuples.
+    padding_ = to_ntuple(padding, n=signal.ndim - 2)
+    stride_ = to_ntuple(stride, n=signal.ndim - 2)
+
     # Pad the input signal & kernel tensors
-    signal_padding = (signal.ndim - 2) * [padding, padding]
+    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
     signal = f.pad(signal, signal_padding)
     kernel_padding = [
         pad
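
As a worked example of the new padding logic (editor's illustration, not from the commit): `f.pad` expects pad amounts for the last dimension first, each as a (before, after) pair, which is why `padding_` is reversed and each entry repeated:

    padding_ = (1, 2)  # pad height by 1 and width by 2 on each side
    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
    print(signal_padding)  # [2, 2, 1, 1] -- width pair first, then height pair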
@@ -65,7 +92,7 @@ def fft_conv(
 
     # Remove extra padded values
     crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
-        slice(0, (signal.size(i) - kernel.size(i) + 1), stride)
+        slice(0, (signal.size(i) - kernel.size(i) + 1), stride_[i - 2])
         for i in range(2, signal.ndim)
     ]
     output = output[crop_slices].contiguous()
@@ -85,18 +112,20 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int,
-        padding: int = 0,
-        stride: int = 1,
+        kernel_size: Union[int, Iterable[int]],
+        padding: Union[int, Iterable[int]] = 0,
+        stride: Union[int, Iterable[int]] = 1,
         bias: bool = True,
+        ndim: int = 1,
     ):
         """
         Args:
             in_channels: (int) Number of channels in input tensors
             out_channels: (int) Number of channels in output tensors
-            kernel_size: (int) Square radius of the convolution kernel
-            padding: (int) Amount of zero-padding to add to the input tensor
-            stride: (int) Stride size for computing output values
+            kernel_size: (Union[int, Iterable[int]]) Size of the convolution kernel
+            padding: (Union[int, Iterable[int]]) Number of zero samples to pad the
+                input on the last dimension.
+            stride: (Union[int, Iterable[int]]) Stride size for computing output values.
             bias: (bool) If True, includes bias, which is added after convolution
         """
         super().__init__()
@@ -107,7 +136,8 @@ def __init__(
         self.stride = stride
         self.use_bias = bias
 
-        self.weight = torch.empty(0)
+        kernel_size = to_ntuple(kernel_size, ndim)
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_size))
         self.bias = nn.Parameter(torch.randn(out_channels,)) if bias else None
 
     def forward(self, signal):
@@ -120,65 +150,6 @@ def forward(self, signal):
         )
 
 
-class FFTConv1d(_FFTConv):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        padding: int = 0,
-        stride: int = 1,
-        bias: bool = True,
-    ):
-        super().__init__(
-            in_channels, out_channels, kernel_size, padding=padding, bias=bias,
-        )
-        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
-
-
-class FFTConv2d(_FFTConv):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        padding: int = 0,
-        stride: int = 1,
-        bias: bool = True,
-    ):
-        super().__init__(
-            in_channels,
-            out_channels,
-            kernel_size,
-            padding=padding,
-            stride=stride,
-            bias=bias,
-        )
-        self.weight = nn.Parameter(
-            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
-        )
-
-
-class FFTConv3d(_FFTConv):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        padding: int = 0,
-        stride: int = 1,
-        bias: bool = True,
-    ):
-        super().__init__(
-            in_channels,
-            out_channels,
-            kernel_size,
-            padding=padding,
-            stride=stride,
-            bias=bias,
-        )
-        self.weight = nn.Parameter(
-            torch.randn(
-                out_channels, in_channels, kernel_size, kernel_size, kernel_size
-            )
-        )
+FFTConv1d = partial(_FFTConv, ndim=1)
+FFTConv2d = partial(_FFTConv, ndim=2)
+FFTConv3d = partial(_FFTConv, ndim=3)
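
A quick consistency check one might run against this commit (editor's sketch, not included in the change): with tuple padding and stride, the FFT-based convolution is expected to match `torch.nn.functional.conv2d` up to FFT round-off. The tolerance and the assumption of exact agreement are mine, not asserted by the commit.

    import torch
    import torch.nn.functional as F
    from fft_conv import fft_conv  # assumed import path

    signal = torch.randn(2, 3, 32, 40)
    kernel = torch.randn(4, 3, 3, 5)  # non-square kernel
    bias = torch.randn(4)

    out_fft = fft_conv(signal, kernel, bias=bias, padding=(1, 2), stride=(2, 1))
    out_ref = F.conv2d(signal, kernel, bias=bias, padding=(1, 2), stride=(2, 1))
    print(torch.allclose(out_fft, out_ref, atol=1e-4))  # expected: True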
