
Commit c4389ec (parent: 3dcc29d)

Pare down the number of unit tests for ease of development; apply 'isort' and 'black' for code style.

File tree: 4 files changed (+46, -78 lines)


fft_conv_pytorch/fft_conv.py

Lines changed: 2 additions & 3 deletions

@@ -84,11 +84,10 @@ def fft_conv(
 
     # internal dilation offsets
     offset = torch.zeros(1, 1, *dilation_)
-    offset[(slice(None), slice(None), *((0,) * n))] = 1.
+    offset[(slice(None), slice(None), *((0,) * n))] = 1.0
 
     # correct the kernel by cutting off unwanted dilation trailing zeros
-    cutoff = tuple(
-        slice(None, -d + 1 if d != 1 else None) for d in dilation_)
+    cutoff = tuple(slice(None, -d + 1 if d != 1 else None) for d in dilation_)
 
     # pad the kernel internally according to the dilation parameters
     kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]
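Apart from the float-literal and line-wrapping fixes, this hunk touches the internal-dilation trick: the kernel is Kronecker-multiplied with a one-hot "offset" block so that each tap is followed by d - 1 zeros, and the cutoff slice trims the trailing zeros so the dilated kernel has the expected length (k - 1) * d + 1. A standalone sketch of what these lines compute in the 1-D case (my illustration, not part of the commit; it assumes a PyTorch build that ships torch.kron):

import torch
import torch.nn.functional as F

kernel = torch.tensor([[[1.0, 2.0, 3.0]]])  # (out=1, in=1, k=3)
dilation_ = (2,)
n = len(dilation_)

# one-hot block: a 1 followed by d - 1 zeros along each spatial dim
offset = torch.zeros(1, 1, *dilation_)
offset[(slice(None), slice(None), *((0,) * n))] = 1.0

# kron inserts the zeros between taps; cutoff drops the trailing ones
cutoff = tuple(slice(None, -d + 1 if d != 1 else None) for d in dilation_)
dilated = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]
print(dilated)  # tensor([[[1., 0., 2., 0., 3.]]]), length (3 - 1) * 2 + 1 = 5

# dilated kernel with dilation=1 matches the raw kernel with dilation=2
x = torch.randn(1, 1, 10)
assert torch.allclose(F.conv1d(x, kernel, dilation=2), F.conv1d(x, dilated))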

setup.py

Lines changed: 2 additions & 13 deletions

@@ -24,19 +24,8 @@ def get_version_tag() -> str:
     description="Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch.",
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
-    install_requires=[
-        "numpy",
-        "torch>=1.7",
-    ],
-    extras_require={
-        "test": [
-            "black",
-            "flake8",
-            "isort",
-            "pytest",
-            "pytest-cov",
-        ]
-    },
+    install_requires=["numpy", "torch>=1.7"],
+    extras_require={"test": ["black", "flake8", "isort", "pytest", "pytest-cov"]},
     classifiers=[
         "Programming Language :: Python :: 3",
         "Operating System :: OS Independent",

tests/test_functional.py

Lines changed: 26 additions & 35 deletions

@@ -4,19 +4,18 @@
 import torch
 import torch.nn.functional as f
 
-from fft_conv_pytorch.fft_conv import _FFTConv, fft_conv, to_ntuple
+from fft_conv_pytorch.fft_conv import fft_conv, to_ntuple
 from tests.utils import _assert_almost_equal, _gcd
 
 
-
-@pytest.mark.parametrize("in_channels", [1, 2, 3])
-@pytest.mark.parametrize("out_channels", [1, 2, 3])
+@pytest.mark.parametrize("in_channels", [2, 3])
+@pytest.mark.parametrize("out_channels", [2, 3])
 @pytest.mark.parametrize("groups", [1, 2, 3])
-@pytest.mark.parametrize("kernel_size", [1, 2, 3])
+@pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("padding", [0, 1])
-@pytest.mark.parametrize("stride", [1, 2, 3])
-@pytest.mark.parametrize("dilation", [1, 2, 3])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("dilation", [1, 2])
+@pytest.mark.parametrize("bias", [True])
 @pytest.mark.parametrize("ndim", [1, 2, 3])
 @pytest.mark.parametrize("input_size", [7, 8])
 def test_fft_conv_functional(
@@ -46,34 +45,30 @@ def test_fft_conv_functional(
     )
 
     kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2)
-    w0 = torch.randn(out_channels, in_channels // groups, *kernel_size,
-                     requires_grad=True)
+    w0 = torch.randn(
+        out_channels, in_channels // groups, *kernel_size, requires_grad=True
+    )
     w1 = w0.detach().clone().requires_grad_()
 
     b0 = torch.randn(out_channels, requires_grad=True) if bias else None
     b1 = b0.detach().clone().requires_grad_() if bias else None
 
-    kwargs = dict(
-        padding=padding,
-        stride=stride,
-        dilation=dilation,
-        groups=groups,
-    )
+    kwargs = dict(padding=padding, stride=stride, dilation=dilation, groups=groups,)
 
     y0 = fft_conv(signal, w0, bias=b0, **kwargs)
     y1 = torch_conv(signal, w1, bias=b1, **kwargs)
-
+
     _assert_almost_equal(y0, y1)
 
 
-@pytest.mark.parametrize("in_channels", [1, 2, 3])
-@pytest.mark.parametrize("out_channels", [1, 2, 3])
+@pytest.mark.parametrize("in_channels", [2, 3])
+@pytest.mark.parametrize("out_channels", [2, 3])
 @pytest.mark.parametrize("groups", [1, 2, 3])
-@pytest.mark.parametrize("kernel_size", [1, 2, 3])
+@pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("padding", [0, 1])
-@pytest.mark.parametrize("stride", [1, 2, 3])
-@pytest.mark.parametrize("dilation", [1, 2, 3])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("dilation", [1, 2])
+@pytest.mark.parametrize("bias", [True])
 @pytest.mark.parametrize("ndim", [1, 2, 3])
 @pytest.mark.parametrize("input_size", [7, 8])
 def test_fft_conv_backward_functional(
@@ -96,28 +91,24 @@ def test_fft_conv_backward_functional(
     signal = torch.randn(batch_size, in_channels, *dims)
 
     kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2)
-    w0 = torch.randn(out_channels, in_channels // groups, *kernel_size,
-                     requires_grad=True)
+    w0 = torch.randn(
+        out_channels, in_channels // groups, *kernel_size, requires_grad=True
+    )
     w1 = w0.detach().clone().requires_grad_()
-
+
     b0 = torch.randn(out_channels, requires_grad=True) if bias else None
     b1 = b0.detach().clone().requires_grad_() if bias else None
 
-    kwargs = dict(
-        padding=padding,
-        stride=stride,
-        dilation=dilation,
-        groups=groups,
-    )
+    kwargs = dict(padding=padding, stride=stride, dilation=dilation, groups=groups,)
 
     y0 = fft_conv(signal, w0, bias=b0, **kwargs)
     y1 = torch_conv(signal, w1, bias=b1, **kwargs)
-
+
     # Compute pseudo-loss and gradient
     y0.sum().backward()
     y1.sum().backward()
-
+
     _assert_almost_equal(w0.grad, w1.grad)
 
-    if bias:
+    if bias:
         _assert_almost_equal(b0.grad, b1.grad)
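The same trimming is applied to tests/test_module.py below. To put numbers on the pare-down, a quick back-of-the-envelope sketch (my estimate, not from the commit; it assumes pytest simply takes the cross product of the stacked parametrize decorators, before any skips inside the test body):

import math

# values per parametrize decorator, in declaration order:
# in_channels, out_channels, groups, kernel_size, padding,
# stride, dilation, bias, ndim, input_size
before = [3, 3, 3, 3, 2, 3, 3, 2, 3, 2]
after = [2, 2, 3, 2, 2, 2, 2, 1, 3, 2]

print(math.prod(before), "->", math.prod(after))  # 17496 -> 1152 cases per test

That is roughly a 15x reduction per test function, which is the "ease of development" the commit message refers to.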

tests/test_module.py

Lines changed: 16 additions & 27 deletions

@@ -4,19 +4,18 @@
 import torch
 import torch.nn.functional as f
 
-from fft_conv_pytorch.fft_conv import _FFTConv, fft_conv
+from fft_conv_pytorch.fft_conv import _FFTConv
 from tests.utils import _assert_almost_equal, _gcd
 
 
-
-@pytest.mark.parametrize("in_channels", [1, 2, 3])
-@pytest.mark.parametrize("out_channels", [1, 2, 3])
+@pytest.mark.parametrize("in_channels", [2, 3])
+@pytest.mark.parametrize("out_channels", [2, 3])
 @pytest.mark.parametrize("groups", [1, 2, 3])
-@pytest.mark.parametrize("kernel_size", [1, 2, 3])
+@pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("padding", [0, 1])
-@pytest.mark.parametrize("stride", [1, 2, 3])
-@pytest.mark.parametrize("dilation", [1, 2, 3])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("dilation", [1, 2])
+@pytest.mark.parametrize("bias", [True])
 @pytest.mark.parametrize("ndim", [1, 2, 3])
 @pytest.mark.parametrize("input_size", [7, 8])
 def test_fft_conv_module(
@@ -51,27 +50,22 @@ def test_fft_conv_module(
     weight = fft_conv_layer.weight
     bias = fft_conv_layer.bias
 
-    kwargs = dict(
-        padding=padding,
-        stride=stride,
-        dilation=dilation,
-        groups=groups,
-    )
+    kwargs = dict(padding=padding, stride=stride, dilation=dilation, groups=groups,)
 
     y0 = fft_conv_layer(signal)
     y1 = torch_conv(signal, weight, bias=bias, **kwargs)
-
+
     _assert_almost_equal(y0, y1)
 
 
-@pytest.mark.parametrize("in_channels", [1, 2, 3])
-@pytest.mark.parametrize("out_channels", [1, 2, 3])
+@pytest.mark.parametrize("in_channels", [2, 3])
+@pytest.mark.parametrize("out_channels", [2, 3])
 @pytest.mark.parametrize("groups", [1, 2, 3])
-@pytest.mark.parametrize("kernel_size", [1, 2, 3])
+@pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("padding", [0, 1])
-@pytest.mark.parametrize("stride", [1, 2, 3])
-@pytest.mark.parametrize("dilation", [1, 2, 3])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("dilation", [1, 2])
+@pytest.mark.parametrize("bias", [True])
 @pytest.mark.parametrize("ndim", [1, 2, 3])
 @pytest.mark.parametrize("input_size", [7, 8])
 def test_fft_conv_backward_module(
@@ -108,12 +102,7 @@ def test_fft_conv_backward_module(
     b0 = fft_conv_layer.bias
     b1 = b0.detach().clone().requires_grad_() if bias else None
 
-    kwargs = dict(
-        padding=padding,
-        stride=stride,
-        dilation=dilation,
-        groups=groups,
-    )
+    kwargs = dict(padding=padding, stride=stride, dilation=dilation, groups=groups,)
 
     y0 = fft_conv_layer(signal)
     y1 = torch_conv(signal, w1, bias=b1, **kwargs)
