Commit 19b37fa
Add backward tests, split test files
1 parent 45fa3a2 commit 19b37fa

4 files changed: +252 -76 lines changed
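Since the old tests/test_fft_conv.py is split into separate functional and module test files, each file can now be run on its own. A minimal sketch of invoking the new files programmatically (an assumption for illustration: pytest is installed and this runs from the repository root; the equivalent command-line invocation works the same way):

import pytest

# Run only the newly split test files; "-q" keeps the parametrized output short.
exit_code = pytest.main(["tests/test_functional.py", "tests/test_module.py", "-q"])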

tests/test_fft_conv.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

tests/test_functional.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
from typing import Iterable, Union

import pytest
import torch
import torch.nn.functional as f

from fft_conv_pytorch.fft_conv import _FFTConv, fft_conv, to_ntuple
from tests.utils import _assert_almost_equal, _gcd


@pytest.mark.parametrize("in_channels", [1, 2, 3])
@pytest.mark.parametrize("out_channels", [1, 2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [1, 2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("stride", [1, 2, 3])
@pytest.mark.parametrize("dilation", [1, 2, 3])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("input_size", [7, 8])
def test_fft_conv_functional(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Iterable[int]],
    padding: Union[int, Iterable[int]],
    stride: Union[int, Iterable[int]],
    dilation: Union[int, Iterable[int]],
    groups: int,
    bias: bool,
    ndim: int,
    input_size: int,
):
    torch_conv = getattr(f, f"conv{ndim}d")
    groups = _gcd(in_channels, _gcd(out_channels, groups))

    batch_size = 2  # TODO: Make this non-constant?
    dims = ndim * [input_size]
    signal = torch.randn(batch_size, in_channels, *dims)
    kwargs = dict(
        bias=torch.randn(out_channels) if bias else None,
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
    )

    kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2)
    w0 = torch.randn(out_channels, in_channels // groups, *kernel_size,
                     requires_grad=True)
    w1 = w0.detach().clone().requires_grad_()

    y0 = fft_conv(signal, w0, **kwargs)
    y1 = torch_conv(signal, w1, **kwargs)

    _assert_almost_equal(y0, y1)


@pytest.mark.parametrize("in_channels", [1, 2, 3])
@pytest.mark.parametrize("out_channels", [1, 2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [1, 2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("stride", [1, 2, 3])
@pytest.mark.parametrize("dilation", [1, 2, 3])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("input_size", [7, 8])
def test_fft_conv_backward_functional(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Iterable[int]],
    padding: Union[int, Iterable[int]],
    stride: Union[int, Iterable[int]],
    dilation: Union[int, Iterable[int]],
    groups: int,
    bias: bool,
    ndim: int,
    input_size: int,
):
    torch_conv = getattr(f, f"conv{ndim}d")
    groups = _gcd(in_channels, _gcd(out_channels, groups))

    batch_size = 2  # TODO: Make this non-constant?
    dims = ndim * [input_size]
    signal = torch.randn(batch_size, in_channels, *dims)

    kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2)
    w0 = torch.randn(out_channels, in_channels // groups, *kernel_size,
                     requires_grad=True)
    w1 = w0.detach().clone().requires_grad_()

    b0 = torch.randn(out_channels, requires_grad=True) if bias else None
    b1 = b0.detach().clone().requires_grad_() if bias else None

    kwargs = dict(
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
    )

    y0 = fft_conv(signal, w0, bias=b0, **kwargs)
    y1 = torch_conv(signal, w1, bias=b1, **kwargs)

    # Compute pseudo-loss and gradient
    y0.sum().backward()
    y1.sum().backward()

    _assert_almost_equal(w0.grad, w1.grad)

    if bias:
        _assert_almost_equal(b0.grad, b1.grad)
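Both functional tests call to_ntuple from fft_conv_pytorch.fft_conv to expand a scalar kernel size into one entry per spatial dimension before building the weight tensor. A rough, purely illustrative stand-in for the behavior these tests assume (the real helper lives in the library; _to_ntuple_sketch is a hypothetical name):

from typing import Iterable, Tuple, Union

def _to_ntuple_sketch(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
    # An int is repeated n times; an iterable is passed through as a tuple.
    if isinstance(val, int):
        return n * (val,)
    return tuple(val)

# e.g. _to_ntuple_sketch(3, n=2) -> (3, 3), the kernel shape for a 2D signal.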

tests/test_module.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
from typing import Iterable, Union

import pytest
import torch
import torch.nn.functional as f

from fft_conv_pytorch.fft_conv import _FFTConv, fft_conv
from tests.utils import _assert_almost_equal, _gcd


@pytest.mark.parametrize("in_channels", [1, 2, 3])
@pytest.mark.parametrize("out_channels", [1, 2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [1, 2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("stride", [1, 2, 3])
@pytest.mark.parametrize("dilation", [1, 2, 3])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("input_size", [7, 8])
def test_fft_conv_module(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Iterable[int]],
    padding: Union[int, Iterable[int]],
    stride: Union[int, Iterable[int]],
    dilation: Union[int, Iterable[int]],
    groups: int,
    bias: bool,
    ndim: int,
    input_size: int,
):
    torch_conv = getattr(f, f"conv{ndim}d")
    groups = _gcd(in_channels, _gcd(out_channels, groups))
    fft_conv_layer = _FFTConv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        dilation=1,
        groups=groups,
        bias=bias,
        ndim=ndim,
    )
    batch_size = 2  # TODO: Make this non-constant?
    dims = ndim * [input_size]
    signal = torch.randn(batch_size, in_channels, *dims)
    kwargs = dict(
        bias=fft_conv_layer.bias,
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
    )

    y0 = fft_conv_layer(signal)
    y1 = torch_conv(signal, fft_conv_layer._weight, **kwargs)

    _assert_almost_equal(y0, y1)


@pytest.mark.parametrize("in_channels", [1, 2, 3])
@pytest.mark.parametrize("out_channels", [1, 2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [1, 2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("stride", [1, 2, 3])
@pytest.mark.parametrize("dilation", [1, 2, 3])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("input_size", [7, 8])
def test_fft_conv_backward_module(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Iterable[int]],
    padding: Union[int, Iterable[int]],
    stride: Union[int, Iterable[int]],
    dilation: Union[int, Iterable[int]],
    groups: int,
    bias: bool,
    ndim: int,
    input_size: int,
):
    torch_conv = getattr(f, f"conv{ndim}d")
    groups = _gcd(in_channels, _gcd(out_channels, groups))
    fft_conv_layer = _FFTConv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
        bias=bias,
        ndim=ndim,
    )
    batch_size = 2  # TODO: Make this non-constant?
    dims = ndim * [input_size]
    signal = torch.randn(batch_size, in_channels, *dims)

    w0 = fft_conv_layer.weight
    w1 = w0.detach().clone().requires_grad_()
    b0 = fft_conv_layer.bias
    b1 = b0.detach().clone().requires_grad_() if bias else None

    kwargs = dict(
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
    )

    y0 = fft_conv_layer(signal)
    y1 = torch_conv(signal, w1, bias=b1, **kwargs)

    # Compute pseudo-loss and gradient
    y0.sum().backward()
    y1.sum().backward()

    _assert_almost_equal(w0.grad, w1.grad)
    if bias:
        _assert_almost_equal(b0.grad, b1.grad)
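The backward tests copy the layer's parameters with detach().clone().requires_grad_() so that the reference torch_conv accumulates gradients into independent leaf tensors rather than into the FFT layer's own parameters, and the two gradients can then be compared. A minimal sketch of why that matters (standard PyTorch behavior, not specific to this commit):

import torch

w0 = torch.randn(3, requires_grad=True)
w1 = w0.detach().clone().requires_grad_()  # independent leaf with the same values

(w0 * 2).sum().backward()
(w1 * 5).sum().backward()

# Each tensor holds only the gradient from its own graph.
assert torch.allclose(w0.grad, torch.full_like(w0, 2.0))
assert torch.allclose(w1.grad, torch.full_like(w1, 5.0))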

tests/utils.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import torch
from torch import Tensor


def _assert_almost_equal(x: Tensor, y: Tensor) -> bool:
    abs_error = torch.abs(x - y)
    assert abs_error.mean().item() < 1e-5
    assert abs_error.max().item() < 1e-4
    return True


def _gcd(x: int, y: int) -> int:
    while y:
        x, y = y, x % y
    return x
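The tests use this helper as groups = _gcd(in_channels, _gcd(out_channels, groups)), reducing the parametrized groups value to one that divides both channel counts, as PyTorch's grouped convolutions require. A small illustration (arbitrary example values, not taken from the tests):

from tests.utils import _gcd

in_channels, out_channels, groups = 2, 3, 3
groups = _gcd(in_channels, _gcd(out_channels, groups))  # _gcd(2, _gcd(3, 3)) == _gcd(2, 3) == 1
assert groups == 1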
