Skip to content

Commit ec6aa74

Browse files
committed
Refactor some code into 'fft_conv_pytorch/utils.py' as a shared resource
1 parent f46fe7c commit ec6aa74

File tree

4 files changed

+42
-38
lines changed

4 files changed

+42
-38
lines changed

doc/scripts/generate_benchmark_plot.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from functools import lru_cache, partial
2-
from timeit import Timer
3-
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Union
2+
from typing import Dict, Iterable, List, Optional, Sequence, Union
43

54
import matplotlib.pyplot as plt
65
import numpy as np
@@ -9,26 +8,7 @@
98
from tqdm import tqdm
109

1110
from fft_conv_pytorch.fft_conv import fft_conv, to_ntuple
12-
13-
14-
class Benchmark(NamedTuple):
15-
mean: float
16-
std: float
17-
18-
def __repr__(self):
19-
return f"BenchmarkResult(mean: {self.mean:.3e}, std: {self.std:.3e})"
20-
21-
def __str__(self):
22-
return f"({self.mean:.3e} \u00B1 {self.std:.3e}) s"
23-
24-
25-
def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchmark:
26-
timer = Timer(
27-
"fn(*args, **kwargs)",
28-
globals={"fn": fn, "args": args, "kwargs": kwargs},
29-
)
30-
times = timer.repeat(number=1, repeat=num_iterations + 1)
31-
return Benchmark(np.mean(times[1:]).item(), np.std(times[1:]).item())
11+
from fft_conv_pytorch.utils import Benchmark, benchmark
3212

3313

3414
@lru_cache(maxsize=1)

fft_conv_pytorch/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from timeit import Timer
2+
from typing import Callable, NamedTuple
3+
4+
import numpy as np
5+
import torch
6+
from torch import Tensor
7+
8+
9+
class Benchmark(NamedTuple):
10+
mean: float
11+
std: float
12+
13+
def __repr__(self):
14+
return f"BenchmarkResult(mean: {self.mean:.3e}, std: {self.std:.3e})"
15+
16+
def __str__(self):
17+
return f"({self.mean:.3e} \u00B1 {self.std:.3e}) s"
18+
19+
20+
def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchmark:
21+
timer = Timer(
22+
"fn(*args, **kwargs)",
23+
globals={"fn": fn, "args": args, "kwargs": kwargs},
24+
)
25+
times = timer.repeat(number=1, repeat=num_iterations + 1)
26+
return Benchmark(np.mean(times[1:]).item(), np.std(times[1:]).item())
27+
28+
29+
def _assert_almost_equal(x: Tensor, y: Tensor) -> bool:
30+
abs_error = torch.abs(x - y)
31+
assert abs_error.mean().item() < 5e-5
32+
assert abs_error.max().item() < 1e-4
33+
return True
34+
35+
36+
def _gcd(x: int, y: int) -> int:
37+
while y:
38+
x, y = y, x % y
39+
return x

tests/test_functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.nn.functional as f
66

77
from fft_conv_pytorch.fft_conv import fft_conv, to_ntuple
8-
from tests.utils import _assert_almost_equal, _gcd
8+
from fft_conv_pytorch.utils import _assert_almost_equal, _gcd
99

1010

1111
@pytest.mark.parametrize("in_channels", [2, 3])

tests/utils.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments
 (0)