Commit 3dcc29d

Merge pull request #13 from aretor/master
Perform dilation with Kronecker product
2 parents c18e4b9 + 45eeae6 commit 3dcc29d
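The idea behind the change is that kernel dilation can be written as a Kronecker product: the kernel is multiplied (via torch.kron) with a small block that is zero everywhere except its first entry, which spaces the kernel taps `dilation` steps apart, and the trailing zeros are then trimmed. Below is a minimal, self-contained sketch of that trick; the toy 1-D kernel and the dilation value are illustrative only, while the variable names follow the diff further down.

import torch

# Toy 1-D kernel and dilation factor (illustrative values only).
kernel = torch.tensor([[[1., 2., 3.]]])   # shape (out_channels, in_channels, 3)
dilation_ = (2,)
n = len(dilation_)

# Offset block: zeros of shape (1, 1, *dilation_) with a single 1 in front.
offset = torch.zeros(1, 1, *dilation_)
offset[(slice(None), slice(None), *((0,) * n))] = 1.

# The Kronecker product interleaves zeros between kernel taps; the cutoff
# slice drops the trailing zeros, so the effective size is (k - 1) * d + 1.
cutoff = tuple(slice(None, -d + 1 if d != 1 else None) for d in dilation_)
dilated = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

print(dilated)  # tensor([[[1., 0., 2., 0., 3.]]]) -- effective kernel size 5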

5 files changed: +286 −94 lines changed


fft_conv_pytorch/fft_conv.py

Lines changed: 21 additions & 18 deletions
@@ -57,6 +57,7 @@ def fft_conv(
     bias: Tensor = None,
     padding: Union[int, Iterable[int]] = 0,
     stride: Union[int, Iterable[int]] = 1,
+    dilation: Union[int, Iterable[int]] = 1,
     groups: int = 1,
 ) -> Tensor:
     """Performs N-d convolution of Tensors using a fast fourier transform, which
@@ -74,9 +75,23 @@ def fft_conv(
     Returns:
         (Tensor) Convolved tensor
     """
-    # Cast padding & stride to tuples.
-    padding_ = to_ntuple(padding, n=signal.ndim - 2)
-    stride_ = to_ntuple(stride, n=signal.ndim - 2)
+
+    # Cast padding, stride & dilation to tuples.
+    n = signal.ndim - 2
+    padding_ = to_ntuple(padding, n=n)
+    stride_ = to_ntuple(stride, n=n)
+    dilation_ = to_ntuple(dilation, n=n)
+
+    # internal dilation offsets
+    offset = torch.zeros(1, 1, *dilation_)
+    offset[(slice(None), slice(None), *((0,) * n))] = 1.
+
+    # correct the kernel by cutting off unwanted dilation trailing zeros
+    cutoff = tuple(
+        slice(None, -d + 1 if d != 1 else None) for d in dilation_)
+
+    # pad the kernel internally according to the dilation parameters
+    kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

     # Pad the input signal & kernel tensors
     signal_padding = [p for p in padding_[::-1] for _ in range(2)]
@@ -167,21 +182,8 @@ def __init__(
         )

         kernel_size = to_ntuple(kernel_size, ndim)
-        dilation = to_ntuple(dilation, ndim)
-        total_size = tuple(
-            ((ks - 1) * dil + 1)
-            for ks, dil in zip(kernel_size, dilation)
-        )
-        weight = torch.zeros(out_channels, in_channels // groups, *total_size)
-        fill = torch.randn(out_channels, in_channels // groups, *kernel_size)
-        ids = tuple(
-            torch.arange(0, tot_sz, dil)
-            for tot_sz, dil in zip(total_size, dilation)
-        )
-
-        # workaround bc PyTorch doesn't support [:, :, <tensor tuple>] indexing
-        weight[(slice(None), slice(None),) + torch.meshgrid(*ids)] = fill
-
+        weight = torch.randn(out_channels, in_channels // groups, *kernel_size)
+
         self.weight = nn.Parameter(weight)
         self.bias = nn.Parameter(torch.randn(out_channels)) if bias else None
@@ -192,6 +194,7 @@ def forward(self, signal):
             bias=self.bias,
             padding=self.padding,
             stride=self.stride,
+            dilation=self.dilation,
             groups=self.groups,
         )

tests/test_fft_conv.py

Lines changed: 0 additions & 76 deletions
This file was deleted.
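The deleted test file is superseded by the parametrized tests below. As a quick standalone check that the new `dilation` argument of fft_conv agrees with PyTorch's native dilated convolution, a sketch along these lines should work (tensor sizes and the tolerance are illustrative, not taken from the diff):

import torch
import torch.nn.functional as F

from fft_conv_pytorch.fft_conv import fft_conv

signal = torch.randn(2, 3, 16)   # (batch, in_channels, length)
kernel = torch.randn(4, 3, 3)    # (out_channels, in_channels, kernel_size)
bias = torch.randn(4)

# Same arguments for both implementations; dilation=2 exercises the new path.
y_fft = fft_conv(signal, kernel, bias=bias, padding=2, stride=1, dilation=2)
y_ref = F.conv1d(signal, kernel, bias=bias, padding=2, stride=1, dilation=2)

assert torch.allclose(y_fft, y_ref, atol=1e-5)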

tests/test_functional.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
from typing import Iterable, Union

import pytest
import torch
import torch.nn.functional as f

from fft_conv_pytorch.fft_conv import _FFTConv, fft_conv, to_ntuple
from tests.utils import _assert_almost_equal, _gcd


@pytest.mark.parametrize("in_channels", [1, 2, 3])
@pytest.mark.parametrize("out_channels", [1, 2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [1, 2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("stride", [1, 2, 3])
@pytest.mark.parametrize("dilation", [1, 2, 3])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("input_size", [7, 8])
def test_fft_conv_functional(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Iterable[int]],
    padding: Union[int, Iterable[int]],
    stride: Union[int, Iterable[int]],
    dilation: Union[int, Iterable[int]],
    groups: int,
    bias: bool,
    ndim: int,
    input_size: int,
):
    torch_conv = getattr(f, f"conv{ndim}d")
    groups = _gcd(in_channels, _gcd(out_channels, groups))

    batch_size = 2  # TODO: Make this non-constant?
    dims = ndim * [input_size]
    signal = torch.randn(batch_size, in_channels, *dims)
    kwargs = dict(
        bias=torch.randn(out_channels) if bias else None,
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
    )

    kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2)
    w0 = torch.randn(out_channels, in_channels // groups, *kernel_size,
                     requires_grad=True)
    w1 = w0.detach().clone().requires_grad_()

    b0 = torch.randn(out_channels, requires_grad=True) if bias else None
    b1 = b0.detach().clone().requires_grad_() if bias else None

    kwargs = dict(
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
    )

    y0 = fft_conv(signal, w0, bias=b0, **kwargs)
    y1 = torch_conv(signal, w1, bias=b1, **kwargs)

    _assert_almost_equal(y0, y1)


@pytest.mark.parametrize("in_channels", [1, 2, 3])
@pytest.mark.parametrize("out_channels", [1, 2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [1, 2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("stride", [1, 2, 3])
@pytest.mark.parametrize("dilation", [1, 2, 3])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("input_size", [7, 8])
def test_fft_conv_backward_functional(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Iterable[int]],
    padding: Union[int, Iterable[int]],
    stride: Union[int, Iterable[int]],
    dilation: Union[int, Iterable[int]],
    groups: int,
    bias: bool,
    ndim: int,
    input_size: int,
):
    torch_conv = getattr(f, f"conv{ndim}d")
    groups = _gcd(in_channels, _gcd(out_channels, groups))

    batch_size = 2  # TODO: Make this non-constant?
    dims = ndim * [input_size]
    signal = torch.randn(batch_size, in_channels, *dims)

    kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2)
    w0 = torch.randn(out_channels, in_channels // groups, *kernel_size,
                     requires_grad=True)
    w1 = w0.detach().clone().requires_grad_()

    b0 = torch.randn(out_channels, requires_grad=True) if bias else None
    b1 = b0.detach().clone().requires_grad_() if bias else None

    kwargs = dict(
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
    )

    y0 = fft_conv(signal, w0, bias=b0, **kwargs)
    y1 = torch_conv(signal, w1, bias=b1, **kwargs)

    # Compute pseudo-loss and gradient
    y0.sum().backward()
    y1.sum().backward()

    _assert_almost_equal(w0.grad, w1.grad)

    if bias:
        _assert_almost_equal(b0.grad, b1.grad)
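Both test files import _gcd and _assert_almost_equal from tests.utils, which is not part of this diff. A plausible sketch of what those helpers might look like (hypothetical, the actual implementations may differ): _gcd shrinks groups to a divisor of both channel counts, and _assert_almost_equal compares tensors with a loose tolerance, since the FFT path accumulates small floating-point error.

# Hypothetical tests/utils.py -- not part of this commit, shown only to make
# the tests above self-contained.
import torch


def _gcd(a: int, b: int) -> int:
    # Greatest common divisor (Euclid), so `groups` divides both channel counts.
    while b:
        a, b = b, a % b
    return a


def _assert_almost_equal(x: torch.Tensor, y: torch.Tensor) -> None:
    # Loose element-wise comparison to absorb FFT round-off error.
    assert x.shape == y.shape
    assert torch.allclose(x, y, atol=1e-5)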

tests/test_module.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
from typing import Iterable, Union

import pytest
import torch
import torch.nn.functional as f

from fft_conv_pytorch.fft_conv import _FFTConv, fft_conv
from tests.utils import _assert_almost_equal, _gcd


@pytest.mark.parametrize("in_channels", [1, 2, 3])
@pytest.mark.parametrize("out_channels", [1, 2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [1, 2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("stride", [1, 2, 3])
@pytest.mark.parametrize("dilation", [1, 2, 3])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("input_size", [7, 8])
def test_fft_conv_module(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Iterable[int]],
    padding: Union[int, Iterable[int]],
    stride: Union[int, Iterable[int]],
    dilation: Union[int, Iterable[int]],
    groups: int,
    bias: bool,
    ndim: int,
    input_size: int,
):
    torch_conv = getattr(f, f"conv{ndim}d")
    groups = _gcd(in_channels, _gcd(out_channels, groups))
    fft_conv_layer = _FFTConv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
        bias=bias,
        ndim=ndim,
    )
    batch_size = 2  # TODO: Make this non-constant?
    dims = ndim * [input_size]
    signal = torch.randn(batch_size, in_channels, *dims)

    weight = fft_conv_layer.weight
    bias = fft_conv_layer.bias

    kwargs = dict(
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
    )

    y0 = fft_conv_layer(signal)
    y1 = torch_conv(signal, weight, bias=bias, **kwargs)

    _assert_almost_equal(y0, y1)


@pytest.mark.parametrize("in_channels", [1, 2, 3])
@pytest.mark.parametrize("out_channels", [1, 2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [1, 2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("stride", [1, 2, 3])
@pytest.mark.parametrize("dilation", [1, 2, 3])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("input_size", [7, 8])
def test_fft_conv_backward_module(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Iterable[int]],
    padding: Union[int, Iterable[int]],
    stride: Union[int, Iterable[int]],
    dilation: Union[int, Iterable[int]],
    groups: int,
    bias: bool,
    ndim: int,
    input_size: int,
):
    torch_conv = getattr(f, f"conv{ndim}d")
    groups = _gcd(in_channels, _gcd(out_channels, groups))
    fft_conv_layer = _FFTConv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
        bias=bias,
        ndim=ndim,
    )
    batch_size = 2  # TODO: Make this non-constant?
    dims = ndim * [input_size]
    signal = torch.randn(batch_size, in_channels, *dims)

    w0 = fft_conv_layer.weight
    w1 = w0.detach().clone().requires_grad_()
    b0 = fft_conv_layer.bias
    b1 = b0.detach().clone().requires_grad_() if bias else None

    kwargs = dict(
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
    )

    y0 = fft_conv_layer(signal)
    y1 = torch_conv(signal, w1, bias=b1, **kwargs)

    # Compute pseudo-loss and gradient
    y0.sum().backward()
    y1.sum().backward()

    _assert_almost_equal(w0.grad, w1.grad)
    if bias:
        _assert_almost_equal(b0.grad, b1.grad)
