Skip to content

Commit 22e21c7

Browse files
committed
Perform dilation with Kronecker product
1 parent 19b37fa commit 22e21c7

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

fft_conv_pytorch/fft_conv.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def fft_conv(
5757
bias: Tensor = None,
5858
padding: Union[int, Iterable[int]] = 0,
5959
stride: Union[int, Iterable[int]] = 1,
60+
dilation: Union[int, Iterable[int]] = 1,
6061
groups: int = 1,
6162
) -> Tensor:
6263
"""Performs N-d convolution of Tensors using a fast fourier transform, which
@@ -74,9 +75,23 @@ def fft_conv(
7475
Returns:
7576
(Tensor) Convolved tensor
7677
"""
77-
# Cast padding & stride to tuples.
78-
padding_ = to_ntuple(padding, n=signal.ndim - 2)
79-
stride_ = to_ntuple(stride, n=signal.ndim - 2)
78+
79+
# Cast padding, stride & dilation to tuples.
80+
n = signal.ndim - 2
81+
padding_ = to_ntuple(padding, n=n)
82+
stride_ = to_ntuple(stride, n=n)
83+
dilation_ = to_ntuple(dilation, n=n)
84+
85+
# internal dilation offsets
86+
offset = torch.zeros(1, 1, *dilation_)
87+
offset[(slice(None), slice(None), *((0,) * n))] = 1.
88+
89+
# correct the kernel by cutting off unwanted dilation trailing zeros
90+
cutoff = tuple(
91+
slice(None, -d + 1 if d != 1 else None) for d in dilation_)
92+
93+
# pad the kernel internally according to the dilation parameters
94+
kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]
8095

8196
# Pad the input signal & kernel tensors
8297
signal_padding = [p for p in padding_[::-1] for _ in range(2)]
@@ -167,21 +182,8 @@ def __init__(
167182
)
168183

169184
kernel_size = to_ntuple(kernel_size, ndim)
170-
dilation = to_ntuple(dilation, ndim)
171-
total_size = tuple(
172-
((ks - 1) * dil + 1)
173-
for ks, dil in zip(kernel_size, dilation)
174-
)
175-
weight = torch.zeros(out_channels, in_channels // groups, *total_size)
176-
fill = torch.randn(out_channels, in_channels // groups, *kernel_size)
177-
ids = tuple(
178-
torch.arange(0, tot_sz, dil)
179-
for tot_sz, dil in zip(total_size, dilation)
180-
)
181-
182-
# workaround bc PyTorch doesn't support [:, :, <tensor tuple>] indexing
183-
weight[(slice(None), slice(None),) + torch.meshgrid(*ids)] = fill
184-
185+
weight = torch.randn(out_channels, in_channels // groups, *kernel_size)
186+
185187
self.weight = nn.Parameter(weight)
186188
self.bias = nn.Parameter(torch.randn(out_channels)) if bias else None
187189

@@ -192,6 +194,7 @@ def forward(self, signal):
192194
bias=self.bias,
193195
padding=self.padding,
194196
stride=self.stride,
197+
dilation=self.dilation,
195198
groups=self.groups,
196199
)
197200

0 commit comments

Comments
 (0)