Skip to content

Commit a6a1b86

Browse files
authored
Merge pull request #34 from BonnBytes/resize
Add resize
2 parents ee2e479 + 1e78f2e commit a6a1b86

File tree

8 files changed

+61
-15
lines changed

8 files changed

+61
-15
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 1.0.1-dev
2+
current_version = 1.0.1
33
commit = True
44
tag = False
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<release>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P<build>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,6 @@ logs/
172172
*curated*
173173
*.pdf
174174
*.tex
175-
!images/fwd_computation.png
175+
!images/fwd_computation.png
176+
.python-version
177+
.DS_Store

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ Here are the other arguments and defaults used.
7575
--wavelet Choice of wavelet. (default: Haar)
7676
--max_level wavelet decomposition level (default: 4)
7777
--log_scale Use log scaling for wavelets. (default: False)
78+
--resize Additional resizing. (default: None)
7879
7980
**We conduct all experiments with the `Haar` wavelet and a transformation/decomposition level of `4` for `256x256` images.**
81+
**The choice of max_level depends on the image resolution, so that sufficient spatial and frequency information is maintained: use level 4 for 256x256 images, level 3 for 128x128 images, and so on.**
8082
In future, we plan to release the jax-version of this code.
8183

8284
Citation

scripts/wpkl/wpkl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
import torchvision.transforms as tv
99
from tqdm import tqdm
1010

11-
from src.pytorchfwd.freq_math import compute_kl_divergence, forward_wavelet_packet_transform
11+
from src.pytorchfwd.freq_math import (
12+
compute_kl_divergence,
13+
forward_wavelet_packet_transform,
14+
)
1215
from src.pytorchfwd.utils import ImagePathDataset, _parse_args
1316

1417
th.set_default_dtype(th.float64)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
##########################
44
[metadata]
55
name = pytorchfwd
6-
version = 1.0.1-dev
6+
version = 1.0.1
77
description = Compute Frechet wavelet distances
88
long_description = file: README.rst
99
long_description_content_type = text/x-rst

src/pytorchfwd/fwd.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
import pathlib
5-
from typing import List, Tuple
5+
from typing import List, Tuple, Union
66

77
import numpy as np
88
import torch as th
@@ -61,7 +61,12 @@ def gpu_cov(tensor_):
6161

6262

6363
def calculate_path_statistics(
64-
path: str, wavelet: str, max_level: int, log_scale: bool, batch_size: int
64+
path: str,
65+
wavelet: str,
66+
max_level: int,
67+
log_scale: bool,
68+
batch_size: int,
69+
resize: Union[int, None],
6570
) -> Tuple[np.ndarray, ...]:
6671
"""Compute mean and sigma for given path.
6772
@@ -71,6 +76,7 @@ def calculate_path_statistics(
7176
max_level (int): Decomposition level.
7277
log_scale (bool): Apply log scale.
7378
batch_size (int): Batch size for packet decomposition.
79+
resize (Union[int, None]): Optional resize option.
7480
7581
Raises:
7682
ValueError: Error if mu and sigma cannot be calculated.
@@ -88,8 +94,13 @@ def calculate_path_statistics(
8894
img_names = sorted(
8995
[name for ext in IMAGE_EXTS for name in posfix_path.glob(f"*.{ext}")]
9096
)
97+
transfs_list = []
98+
if resize is not None:
99+
print(f"Resizing images to {(resize, resize)} resolution")
100+
transfs_list.append(tv.Resize((resize, resize)))
101+
transfs_list.append(tv.ToTensor())
91102
dataloader = th.utils.data.DataLoader(
92-
ImagePathDataset(img_names, transforms=tv.ToTensor()),
103+
ImagePathDataset(img_names, transforms=tv.Compose(transfs_list)),
93104
batch_size=batch_size,
94105
shuffle=False,
95106
drop_last=False,
@@ -122,7 +133,12 @@ def _compute_avg_frechet_distance(mu1, mu2, sigma1, sigma2):
122133

123134

124135
def compute_fwd(
125-
paths: List[str], wavelet: str, max_level: int, log_scale: bool, batch_size: int
136+
paths: List[str],
137+
wavelet: str,
138+
max_level: int,
139+
log_scale: bool,
140+
batch_size: int,
141+
resize: Union[int, None],
126142
) -> float:
127143
"""Compute Frechet Wavelet Distance.
128144
@@ -132,6 +148,7 @@ def compute_fwd(
132148
max_level (int): Decomposition level.
133149
log_scale (bool): Apply log scale.
134150
batch_size (int): Batch size for packet decomposition.
151+
resize (Union[int, None]): Optional resize option.
135152
136153
Raises:
137154
RuntimeError: Error if path doesn't exist.
@@ -145,19 +162,24 @@ def compute_fwd(
145162

146163
print(f"Computing stats for path: {paths[0]}")
147164
mu_1, sigma_1 = calculate_path_statistics(
148-
paths[0], wavelet, max_level, log_scale, batch_size
165+
paths[0], wavelet, max_level, log_scale, batch_size, resize
149166
)
150167
print(f"Computing stats for path: {paths[1]}")
151168
mu_2, sigma_2 = calculate_path_statistics(
152-
paths[1], wavelet, max_level, log_scale, batch_size
169+
paths[1], wavelet, max_level, log_scale, batch_size, resize
153170
)
154171

155172
print("Computing Frechet distances for each packet.")
156173
return _compute_avg_frechet_distance(mu_1, mu_2, sigma_1, sigma_2)
157174

158175

159176
def _save_packets(
160-
paths: List[str], wavelet: str, max_level: int, log_scale: bool, batch_size: int
177+
paths: List[str],
178+
wavelet: str,
179+
max_level: int,
180+
log_scale: bool,
181+
batch_size: int,
182+
resize: Union[int, None],
161183
) -> None:
162184
"""Save packets.
163185
@@ -167,6 +189,7 @@ def _save_packets(
167189
max_level (int): Decomposition level.
168190
log_scale (bool): Apply log scale.
169191
batch_size (int): Batch size for packet decomposition.
192+
resize (Union[int, None]): Optional resize option.
170193
171194
Raises:
172195
RuntimeError: Error if input path is invalid.
@@ -180,7 +203,7 @@ def _save_packets(
180203

181204
print(f"Computing stats for path: {paths[0]}")
182205
mu_1, sigma_1 = calculate_path_statistics(
183-
paths[0], wavelet, max_level, log_scale, batch_size
206+
paths[0], wavelet, max_level, log_scale, batch_size, resize
184207
)
185208
np.savez_compressed(paths[1], mu=mu_1, sigma=sigma_1)
186209

@@ -205,12 +228,22 @@ def main():
205228
th.use_deterministic_algorithms(True)
206229
if args.save_packets:
207230
_save_packets(
208-
args.path, args.wavelet, args.max_level, args.log_scale, args.batch_size
231+
args.path,
232+
args.wavelet,
233+
args.max_level,
234+
args.log_scale,
235+
args.batch_size,
236+
args.resize,
209237
)
210238
return
211239

212240
fwd = compute_fwd(
213-
args.path, args.wavelet, args.max_level, args.log_scale, args.batch_size
241+
args.path,
242+
args.wavelet,
243+
args.max_level,
244+
args.log_scale,
245+
args.batch_size,
246+
args.resize,
214247
)
215248
print(f"FWD: {fwd}")
216249

src/pytorchfwd/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ def _parse_args():
3030
parser.add_argument(
3131
"--log_scale", action="store_true", help="Use log scaling for wavelets."
3232
)
33+
parser.add_argument(
34+
"--resize",
35+
type=int,
36+
default=None,
37+
help="Resize the images to specified resolution.",
38+
)
3339
parser.add_argument(
3440
"--deterministic",
3541
action="store_true",

src/pytorchfwd/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Record the package version."""
22

3-
VERSION = "1.0.1-dev"
3+
VERSION = "1.0.1"

0 commit comments

Comments
 (0)