Skip to content

Commit a662103

Browse files
authored
Merge pull request #16 from fkodom/bug-fix/unit-test-utils
Bug Fix and Torch Compatibility
2 parents 833f271 + 6543868 commit a662103

File tree

8 files changed

+70
-45
lines changed

8 files changed

+70
-45
lines changed

.github/workflows/test.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
python: ["3.7", "3.8", "3.9"]
16-
torch: ["1.7", "1.8", "1.9", "1.10"]
16+
torch: ["1.8", "1.9", "1.10"]
1717

1818
steps:
1919
- name: Checkout
@@ -26,7 +26,7 @@ jobs:
2626

2727
- name: Install Package
2828
run: |
29-
pip install torch==${{ matrix.torch }}
29+
pip install torch~=${{ matrix.torch }}.0
3030
pip install .[test]
3131
3232
- name: Test

doc/scripts/generate_benchmark_plot.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from functools import lru_cache, partial
2-
from timeit import Timer
3-
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Union
2+
from typing import Dict, Iterable, List, Optional, Sequence, Union
43

54
import matplotlib.pyplot as plt
65
import numpy as np
@@ -9,25 +8,7 @@
98
from tqdm import tqdm
109

1110
from fft_conv_pytorch.fft_conv import fft_conv, to_ntuple
12-
13-
14-
class Benchmark(NamedTuple):
15-
mean: float
16-
std: float
17-
18-
def __repr__(self):
19-
return f"BenchmarkResult(mean: {self.mean:.3e}, std: {self.std:.3e})"
20-
21-
def __str__(self):
22-
return f"({self.mean:.3e} \u00B1 {self.std:.3e}) s"
23-
24-
25-
def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchmark:
26-
timer = Timer(
27-
"fn(*args, **kwargs)", globals={"fn": fn, "args": args, "kwargs": kwargs},
28-
)
29-
times = timer.repeat(number=1, repeat=num_iterations + 1)
30-
return Benchmark(np.mean(times[1:]).item(), np.std(times[1:]).item())
11+
from fft_conv_pytorch.utils import Benchmark, benchmark
3112

3213

3314
@lru_cache(maxsize=1)

fft_conv_pytorch/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from timeit import Timer
2+
from typing import Callable, NamedTuple
3+
4+
import numpy as np
5+
import torch
6+
from torch import Tensor
7+
8+
9+
class Benchmark(NamedTuple):
10+
mean: float
11+
std: float
12+
13+
def __repr__(self):
14+
return f"BenchmarkResult(mean: {self.mean:.3e}, std: {self.std:.3e})"
15+
16+
def __str__(self):
17+
return f"({self.mean:.3e} \u00B1 {self.std:.3e}) s"
18+
19+
20+
def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchmark:
21+
timer = Timer(
22+
"fn(*args, **kwargs)",
23+
globals={"fn": fn, "args": args, "kwargs": kwargs},
24+
)
25+
times = timer.repeat(number=1, repeat=num_iterations + 1)
26+
return Benchmark(np.mean(times[1:]).item(), np.std(times[1:]).item())
27+
28+
29+
def _assert_almost_equal(x: Tensor, y: Tensor) -> bool:
30+
abs_error = torch.abs(x - y)
31+
assert abs_error.mean().item() < 5e-5
32+
assert abs_error.max().item() < 1e-4
33+
return True
34+
35+
36+
def _gcd(x: int, y: int) -> int:
37+
while y:
38+
x, y = y, x % y
39+
return x

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_version_tag() -> str:
2424
description="Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch.",
2525
long_description=open("README.md").read(),
2626
long_description_content_type="text/markdown",
27-
install_requires=["numpy", "torch>=1.7"],
27+
install_requires=["numpy", "torch>=1.8"],
2828
extras_require={"test": ["black", "flake8", "isort", "pytest", "pytest-cov"]},
2929
classifiers=[
3030
"Programming Language :: Python :: 3",

tests/__init__.py

Whitespace-only changes.

tests/test_functional.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.nn.functional as f
66

77
from fft_conv_pytorch.fft_conv import fft_conv, to_ntuple
8-
from tests.utils import _assert_almost_equal, _gcd
8+
from fft_conv_pytorch.utils import _assert_almost_equal, _gcd
99

1010

1111
@pytest.mark.parametrize("in_channels", [2, 3])
@@ -53,7 +53,12 @@ def test_fft_conv_functional(
5353
b0 = torch.randn(out_channels, requires_grad=True) if bias else None
5454
b1 = b0.detach().clone().requires_grad_() if bias else None
5555

56-
kwargs = dict(padding=padding, stride=stride, dilation=dilation, groups=groups,)
56+
kwargs = dict(
57+
padding=padding,
58+
stride=stride,
59+
dilation=dilation,
60+
groups=groups,
61+
)
5762

5863
y0 = fft_conv(signal, w0, bias=b0, **kwargs)
5964
y1 = torch_conv(signal, w1, bias=b1, **kwargs)
@@ -99,7 +104,12 @@ def test_fft_conv_backward_functional(
99104
b0 = torch.randn(out_channels, requires_grad=True) if bias else None
100105
b1 = b0.detach().clone().requires_grad_() if bias else None
101106

102-
kwargs = dict(padding=padding, stride=stride, dilation=dilation, groups=groups,)
107+
kwargs = dict(
108+
padding=padding,
109+
stride=stride,
110+
dilation=dilation,
111+
groups=groups,
112+
)
103113

104114
y0 = fft_conv(signal, w0, bias=b0, **kwargs)
105115
y1 = torch_conv(signal, w1, bias=b1, **kwargs)

tests/test_module.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.nn.functional as f
66

77
from fft_conv_pytorch.fft_conv import _FFTConv
8-
from tests.utils import _assert_almost_equal, _gcd
8+
from fft_conv_pytorch.utils import _assert_almost_equal, _gcd
99

1010

1111
@pytest.mark.parametrize("in_channels", [2, 3])
@@ -50,7 +50,12 @@ def test_fft_conv_module(
5050
weight = fft_conv_layer.weight
5151
bias = fft_conv_layer.bias
5252

53-
kwargs = dict(padding=padding, stride=stride, dilation=dilation, groups=groups,)
53+
kwargs = dict(
54+
padding=padding,
55+
stride=stride,
56+
dilation=dilation,
57+
groups=groups,
58+
)
5459

5560
y0 = fft_conv_layer(signal)
5661
y1 = torch_conv(signal, weight, bias=bias, **kwargs)
@@ -102,7 +107,12 @@ def test_fft_conv_backward_module(
102107
b0 = fft_conv_layer.bias
103108
b1 = b0.detach().clone().requires_grad_() if bias else None
104109

105-
kwargs = dict(padding=padding, stride=stride, dilation=dilation, groups=groups,)
110+
kwargs = dict(
111+
padding=padding,
112+
stride=stride,
113+
dilation=dilation,
114+
groups=groups,
115+
)
106116

107117
y0 = fft_conv_layer(signal)
108118
y1 = torch_conv(signal, w1, bias=b1, **kwargs)

tests/utils.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments
 (0)