diff --git a/.bumpversion.cfg b/.bumpversion.cfg index d1f5e1a..e319ce2 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.0.1-dev +current_version = 1.0.1 commit = True tag = False parse = (?P\d+)\.(?P\d+)\.(?P\d+)(?:-(?P[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))? diff --git a/.gitignore b/.gitignore index cdd6d26..b6a152b 100644 --- a/.gitignore +++ b/.gitignore @@ -172,4 +172,6 @@ logs/ *curated* *.pdf *.tex -!images/fwd_computation.png \ No newline at end of file +!images/fwd_computation.png +.python-version +.DS_Store \ No newline at end of file diff --git a/README.rst b/README.rst index 1d4ab86..2e49826 100644 --- a/README.rst +++ b/README.rst @@ -75,8 +75,10 @@ Here are the other arguments and defaults used. --wavelet Choice of wavelet. (default: Haar) --max_level wavelet decomposition level (default: 4) --log_scale Use log scaling for wavelets. (default: False) + --resize Additional resizing. (default: None) **We conduct all the experiments with `Haar` wavelet with transformation/decomposition level of `4` for `256x256` image.** +**The choice of max_level depends on the image resolution, so that sufficient spatial and frequency information is maintained. For example, use level 4 for 256x256 images, level 3 for 128x128 images, and so on.** In future, we plan to release the jax-version of this code. 
Citation diff --git a/scripts/wpkl/wpkl.py b/scripts/wpkl/wpkl.py index 1e4cc2d..d057c1c 100644 --- a/scripts/wpkl/wpkl.py +++ b/scripts/wpkl/wpkl.py @@ -8,7 +8,10 @@ import torchvision.transforms as tv from tqdm import tqdm -from src.pytorchfwd.freq_math import compute_kl_divergence, forward_wavelet_packet_transform +from src.pytorchfwd.freq_math import ( + compute_kl_divergence, + forward_wavelet_packet_transform, +) from src.pytorchfwd.utils import ImagePathDataset, _parse_args th.set_default_dtype(th.float64) diff --git a/setup.cfg b/setup.cfg index 7a61201..dfdd5a6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ ########################## [metadata] name = pytorchfwd -version = 1.0.1-dev +version = 1.0.1 description = Compute frecet wavelet distances long_description = file: README.rst long_description_content_type = text/x-rst diff --git a/src/pytorchfwd/fwd.py b/src/pytorchfwd/fwd.py index bf0bb17..565db3e 100644 --- a/src/pytorchfwd/fwd.py +++ b/src/pytorchfwd/fwd.py @@ -2,7 +2,7 @@ import os import pathlib -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import torch as th @@ -61,7 +61,12 @@ def gpu_cov(tensor_): def calculate_path_statistics( - path: str, wavelet: str, max_level: int, log_scale: bool, batch_size: int + path: str, + wavelet: str, + max_level: int, + log_scale: bool, + batch_size: int, + resize: Union[int, None], ) -> Tuple[np.ndarray, ...]: """Compute mean and sigma for given path. @@ -71,6 +76,7 @@ def calculate_path_statistics( max_level (int): Decomposition level. log_scale (bool): Apply log scale. batch_size (int): Batch size for packet decomposition. + resize (Union[int, None]): Optional resize option. Raises: ValueError: Error if mu and sigma cannot be calculated. 
@@ -88,8 +94,13 @@ def calculate_path_statistics( img_names = sorted( [name for ext in IMAGE_EXTS for name in posfix_path.glob(f"*.{ext}")] ) + transfs_list = [] + if resize is not None: + print(f"Resizing images to {(resize, resize)} resolution") + transfs_list.append(tv.Resize((resize, resize))) + transfs_list.append(tv.ToTensor()) dataloader = th.utils.data.DataLoader( - ImagePathDataset(img_names, transforms=tv.ToTensor()), + ImagePathDataset(img_names, transforms=tv.Compose(transfs_list)), batch_size=batch_size, shuffle=False, drop_last=False, @@ -122,7 +133,12 @@ def _compute_avg_frechet_distance(mu1, mu2, sigma1, sigma2): def compute_fwd( - paths: List[str], wavelet: str, max_level: int, log_scale: bool, batch_size: int + paths: List[str], + wavelet: str, + max_level: int, + log_scale: bool, + batch_size: int, + resize: Union[int, None], ) -> float: """Compute Frechet Wavelet Distance. @@ -132,6 +148,7 @@ def compute_fwd( max_level (int): Decomposition level. log_scale (bool): Apply log scale. batch_size (int): Batch size for packet decomposition. + resize (Union[int, None]): Optional resize option. Raises: RuntimeError: Error if path doesn't exist. @@ -145,11 +162,11 @@ def compute_fwd( print(f"Computing stats for path: {paths[0]}") mu_1, sigma_1 = calculate_path_statistics( - paths[0], wavelet, max_level, log_scale, batch_size + paths[0], wavelet, max_level, log_scale, batch_size, resize ) print(f"Computing stats for path: {paths[1]}") mu_2, sigma_2 = calculate_path_statistics( - paths[1], wavelet, max_level, log_scale, batch_size + paths[1], wavelet, max_level, log_scale, batch_size, resize ) print("Computing Frechet distances for each packet.") @@ -157,7 +174,12 @@ def compute_fwd( def _save_packets( - paths: List[str], wavelet: str, max_level: int, log_scale: bool, batch_size: int + paths: List[str], + wavelet: str, + max_level: int, + log_scale: bool, + batch_size: int, + resize: Union[int, None], ) -> None: """Save packets. 
@@ -167,6 +189,7 @@ def _save_packets( max_level (int): Decomposition level. log_scale (bool): Apply log scale. batch_size (int): Batch size for packet decomposition. + resize (Union[int, None]): Optional resize option. Raises: RuntimeError: Error if input path is invalid. @@ -180,7 +203,7 @@ def _save_packets( print(f"Computing stats for path: {paths[0]}") mu_1, sigma_1 = calculate_path_statistics( - paths[0], wavelet, max_level, log_scale, batch_size + paths[0], wavelet, max_level, log_scale, batch_size, resize ) np.savez_compressed(paths[1], mu=mu_1, sigma=sigma_1) @@ -205,12 +228,22 @@ def main(): th.use_deterministic_algorithms(True) if args.save_packets: _save_packets( - args.path, args.wavelet, args.max_level, args.log_scale, args.batch_size + args.path, + args.wavelet, + args.max_level, + args.log_scale, + args.batch_size, + args.resize, ) return fwd = compute_fwd( - args.path, args.wavelet, args.max_level, args.log_scale, args.batch_size + args.path, + args.wavelet, + args.max_level, + args.log_scale, + args.batch_size, + args.resize, ) print(f"FWD: {fwd}") diff --git a/src/pytorchfwd/utils.py b/src/pytorchfwd/utils.py index 61ecb70..854d156 100644 --- a/src/pytorchfwd/utils.py +++ b/src/pytorchfwd/utils.py @@ -30,6 +30,12 @@ def _parse_args(): parser.add_argument( "--log_scale", action="store_true", help="Use log scaling for wavelets." ) + parser.add_argument( + "--resize", + type=int, + default=None, + help="Resize the images to specified resolution.", + ) parser.add_argument( "--deterministic", action="store_true", diff --git a/src/pytorchfwd/version.py b/src/pytorchfwd/version.py index 1d8c9b0..0ee4b45 100644 --- a/src/pytorchfwd/version.py +++ b/src/pytorchfwd/version.py @@ -1,3 +1,3 @@ """Record the package version.""" -VERSION = "1.0.1-dev" +VERSION = "1.0.1"