Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.0.1-dev
current_version = 1.0.1
commit = True
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<release>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P<build>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?
Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,6 @@ logs/
*curated*
*.pdf
*.tex
!images/fwd_computation.png
!images/fwd_computation.png
.python-version
.DS_Store
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ Here are the other arguments and defaults used.
--wavelet Choice of wavelet. (default: Haar)
--max_level wavelet decomposition level (default: 4)
--log_scale Use log scaling for wavelets. (default: False)
--resize Additional resizing. (default: None)

**We conduct all the experiments with `Haar` wavelet with transformation/decomposition level of `4` for `256x256` image.**
**The choice of max_level depends on the image resolution, so that sufficient spatial and frequency information is maintained. For example, use level 4 for 256x256 images, level 3 for 128x128 images, and so on.**
In the future, we plan to release a JAX version of this code.

Citation
Expand Down
5 changes: 4 additions & 1 deletion scripts/wpkl/wpkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import torchvision.transforms as tv
from tqdm import tqdm

from src.pytorchfwd.freq_math import compute_kl_divergence, forward_wavelet_packet_transform
from src.pytorchfwd.freq_math import (
compute_kl_divergence,
forward_wavelet_packet_transform,
)
from src.pytorchfwd.utils import ImagePathDataset, _parse_args

th.set_default_dtype(th.float64)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
##########################
[metadata]
name = pytorchfwd
version = 1.0.1-dev
version = 1.0.1
description = Compute Frechet wavelet distances
long_description = file: README.rst
long_description_content_type = text/x-rst
Expand Down
53 changes: 43 additions & 10 deletions src/pytorchfwd/fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import pathlib
from typing import List, Tuple
from typing import List, Tuple, Union

import numpy as np
import torch as th
Expand Down Expand Up @@ -61,7 +61,12 @@ def gpu_cov(tensor_):


def calculate_path_statistics(
path: str, wavelet: str, max_level: int, log_scale: bool, batch_size: int
path: str,
wavelet: str,
max_level: int,
log_scale: bool,
batch_size: int,
resize: Union[int, None],
) -> Tuple[np.ndarray, ...]:
"""Compute mean and sigma for given path.

Expand All @@ -71,6 +76,7 @@ def calculate_path_statistics(
max_level (int): Decomposition level.
log_scale (bool): Apply log scale.
batch_size (int): Batch size for packet decomposition.
resize (Union[int, None]): Optional resize option.

Raises:
ValueError: Error if mu and sigma cannot be calculated.
Expand All @@ -88,8 +94,13 @@ def calculate_path_statistics(
img_names = sorted(
[name for ext in IMAGE_EXTS for name in posfix_path.glob(f"*.{ext}")]
)
transfs_list = []
if resize is not None:
print(f"Resizing images to {(resize, resize)} resolution")
transfs_list.append(tv.Resize((resize, resize)))
transfs_list.append(tv.ToTensor())
dataloader = th.utils.data.DataLoader(
ImagePathDataset(img_names, transforms=tv.ToTensor()),
ImagePathDataset(img_names, transforms=tv.Compose(transfs_list)),
batch_size=batch_size,
shuffle=False,
drop_last=False,
Expand Down Expand Up @@ -122,7 +133,12 @@ def _compute_avg_frechet_distance(mu1, mu2, sigma1, sigma2):


def compute_fwd(
paths: List[str], wavelet: str, max_level: int, log_scale: bool, batch_size: int
paths: List[str],
wavelet: str,
max_level: int,
log_scale: bool,
batch_size: int,
resize: Union[int, None],
) -> float:
"""Compute Frechet Wavelet Distance.

Expand All @@ -132,6 +148,7 @@ def compute_fwd(
max_level (int): Decomposition level.
log_scale (bool): Apply log scale.
batch_size (int): Batch size for packet decomposition.
resize (Union[int, None]): Optional resize option.

Raises:
RuntimeError: Error if path doesn't exist.
Expand All @@ -145,19 +162,24 @@ def compute_fwd(

print(f"Computing stats for path: {paths[0]}")
mu_1, sigma_1 = calculate_path_statistics(
paths[0], wavelet, max_level, log_scale, batch_size
paths[0], wavelet, max_level, log_scale, batch_size, resize
)
print(f"Computing stats for path: {paths[1]}")
mu_2, sigma_2 = calculate_path_statistics(
paths[1], wavelet, max_level, log_scale, batch_size
paths[1], wavelet, max_level, log_scale, batch_size, resize
)

print("Computing Frechet distances for each packet.")
return _compute_avg_frechet_distance(mu_1, mu_2, sigma_1, sigma_2)


def _save_packets(
paths: List[str], wavelet: str, max_level: int, log_scale: bool, batch_size: int
paths: List[str],
wavelet: str,
max_level: int,
log_scale: bool,
batch_size: int,
resize: Union[int, None],
) -> None:
"""Save packets.

Expand All @@ -167,6 +189,7 @@ def _save_packets(
max_level (int): Decomposition level.
log_scale (bool): Apply log scale.
batch_size (int): Batch size for packet decomposition.
resize (Union[int, None]): Optional resize option.

Raises:
RuntimeError: Error if input path is invalid.
Expand All @@ -180,7 +203,7 @@ def _save_packets(

print(f"Computing stats for path: {paths[0]}")
mu_1, sigma_1 = calculate_path_statistics(
paths[0], wavelet, max_level, log_scale, batch_size
paths[0], wavelet, max_level, log_scale, batch_size, resize
)
np.savez_compressed(paths[1], mu=mu_1, sigma=sigma_1)

Expand All @@ -205,12 +228,22 @@ def main():
th.use_deterministic_algorithms(True)
if args.save_packets:
_save_packets(
args.path, args.wavelet, args.max_level, args.log_scale, args.batch_size
args.path,
args.wavelet,
args.max_level,
args.log_scale,
args.batch_size,
args.resize,
)
return

fwd = compute_fwd(
args.path, args.wavelet, args.max_level, args.log_scale, args.batch_size
args.path,
args.wavelet,
args.max_level,
args.log_scale,
args.batch_size,
args.resize,
)
print(f"FWD: {fwd}")

Expand Down
6 changes: 6 additions & 0 deletions src/pytorchfwd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def _parse_args():
parser.add_argument(
"--log_scale", action="store_true", help="Use log scaling for wavelets."
)
parser.add_argument(
"--resize",
type=int,
default=None,
help="Resize the images to specified resolution.",
)
parser.add_argument(
"--deterministic",
action="store_true",
Expand Down
2 changes: 1 addition & 1 deletion src/pytorchfwd/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Record the package version."""

VERSION = "1.0.1-dev"
VERSION = "1.0.1"