22
33import os
44import pathlib
5- from typing import List , Tuple
5+ from typing import List , Tuple , Union
66
77import numpy as np
88import torch as th
@@ -61,7 +61,12 @@ def gpu_cov(tensor_):
6161
6262
6363def calculate_path_statistics (
64- path : str , wavelet : str , max_level : int , log_scale : bool , batch_size : int
64+ path : str ,
65+ wavelet : str ,
66+ max_level : int ,
67+ log_scale : bool ,
68+ batch_size : int ,
69+ resize : Union [int , None ],
6570) -> Tuple [np .ndarray , ...]:
6671 """Compute mean and sigma for given path.
6772
@@ -71,6 +76,7 @@ def calculate_path_statistics(
7176 max_level (int): Decomposition level.
7277 log_scale (bool): Apply log scale.
7378 batch_size (int): Batch size for packet decomposition.
79+ resize (Union[int, None]): Optional resize option.
7480
7581 Raises:
7682 ValueError: Error if mu and sigma cannot be calculated.
@@ -88,8 +94,13 @@ def calculate_path_statistics(
8894 img_names = sorted (
8995 [name for ext in IMAGE_EXTS for name in posfix_path .glob (f"*.{ ext } " )]
9096 )
97+ transfs_list = []
98+ if resize is not None :
99+ print (f"Resizing images to { (resize , resize )} resolution" )
100+ transfs_list .append (tv .Resize ((resize , resize )))
101+ transfs_list .append (tv .ToTensor ())
91102 dataloader = th .utils .data .DataLoader (
92- ImagePathDataset (img_names , transforms = tv .ToTensor ( )),
103+ ImagePathDataset (img_names , transforms = tv .Compose ( transfs_list )),
93104 batch_size = batch_size ,
94105 shuffle = False ,
95106 drop_last = False ,
@@ -122,7 +133,12 @@ def _compute_avg_frechet_distance(mu1, mu2, sigma1, sigma2):
122133
123134
124135def compute_fwd (
125- paths : List [str ], wavelet : str , max_level : int , log_scale : bool , batch_size : int
136+ paths : List [str ],
137+ wavelet : str ,
138+ max_level : int ,
139+ log_scale : bool ,
140+ batch_size : int ,
141+ resize : Union [int , None ],
126142) -> float :
127143 """Compute Frechet Wavelet Distance.
128144
@@ -132,6 +148,7 @@ def compute_fwd(
132148 max_level (int): Decomposition level.
133149 log_scale (bool): Apply log scale.
134150 batch_size (int): Batch size for packet decomposition.
151+ resize (Union[int, None]): Optional resize option.
135152
136153 Raises:
137154 RuntimeError: Error if path doesn't exist.
@@ -145,19 +162,24 @@ def compute_fwd(
145162
146163 print (f"Computing stats for path: { paths [0 ]} " )
147164 mu_1 , sigma_1 = calculate_path_statistics (
148- paths [0 ], wavelet , max_level , log_scale , batch_size
165+ paths [0 ], wavelet , max_level , log_scale , batch_size , resize
149166 )
150167 print (f"Computing stats for path: { paths [1 ]} " )
151168 mu_2 , sigma_2 = calculate_path_statistics (
152- paths [1 ], wavelet , max_level , log_scale , batch_size
169+ paths [1 ], wavelet , max_level , log_scale , batch_size , resize
153170 )
154171
155172 print ("Computing Frechet distances for each packet." )
156173 return _compute_avg_frechet_distance (mu_1 , mu_2 , sigma_1 , sigma_2 )
157174
158175
159176def _save_packets (
160- paths : List [str ], wavelet : str , max_level : int , log_scale : bool , batch_size : int
177+ paths : List [str ],
178+ wavelet : str ,
179+ max_level : int ,
180+ log_scale : bool ,
181+ batch_size : int ,
182+ resize : Union [int , None ],
161183) -> None :
162184 """Save packets.
163185
@@ -167,6 +189,7 @@ def _save_packets(
167189 max_level (int): Decomposition level.
168190 log_scale (bool): Apply log scale.
169191 batch_size (int): Batch size for packet decomposition.
192+ resize (Union[int, None]): Optional resize option.
170193
171194 Raises:
172195 RuntimeError: Error if input path is invalid.
@@ -180,7 +203,7 @@ def _save_packets(
180203
181204 print (f"Computing stats for path: { paths [0 ]} " )
182205 mu_1 , sigma_1 = calculate_path_statistics (
183- paths [0 ], wavelet , max_level , log_scale , batch_size
206+ paths [0 ], wavelet , max_level , log_scale , batch_size , resize
184207 )
185208 np .savez_compressed (paths [1 ], mu = mu_1 , sigma = sigma_1 )
186209
@@ -205,12 +228,22 @@ def main():
205228 th .use_deterministic_algorithms (True )
206229 if args .save_packets :
207230 _save_packets (
208- args .path , args .wavelet , args .max_level , args .log_scale , args .batch_size
231+ args .path ,
232+ args .wavelet ,
233+ args .max_level ,
234+ args .log_scale ,
235+ args .batch_size ,
236+ args .resize ,
209237 )
210238 return
211239
212240 fwd = compute_fwd (
213- args .path , args .wavelet , args .max_level , args .log_scale , args .batch_size
241+ args .path ,
242+ args .wavelet ,
243+ args .max_level ,
244+ args .log_scale ,
245+ args .batch_size ,
246+ args .resize ,
214247 )
215248 print (f"FWD: { fwd } " )
216249
0 commit comments