diff --git a/src/pytorch_fid/__init__.py b/src/pytorch_fid/__init__.py
index 7fd229a..c8c347e 100644
--- a/src/pytorch_fid/__init__.py
+++ b/src/pytorch_fid/__init__.py
@@ -1 +1,4 @@
+from pytorch_fid.fid_score import FrechetInceptionDistance
+
 __version__ = '0.2.0'
+__all__ = ['FrechetInceptionDistance']
diff --git a/src/pytorch_fid/fid_score.py b/src/pytorch_fid/fid_score.py
index 1459043..4cf6a88 100755
--- a/src/pytorch_fid/fid_score.py
+++ b/src/pytorch_fid/fid_score.py
@@ -31,17 +31,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import os
-import pathlib
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
 from multiprocessing import cpu_count
+from pathlib import Path
 
 import numpy as np
 import torch
-import torchvision.transforms as TF
 from PIL import Image
 from scipy import linalg
 from torch.nn.functional import adaptive_avg_pool2d
+from torch.utils import data
+from torch.utils.data import DataLoader
+from torchvision import transforms
 
 try:
     from tqdm import tqdm
@@ -52,223 +53,251 @@ def tqdm(x):
 
 from pytorch_fid.inception import InceptionV3
 
-parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-parser.add_argument('--batch-size', type=int, default=50,
-                    help='Batch size to use')
-parser.add_argument('--device', type=str, default=None,
-                    help='Device to use. Like cuda, cuda:0 or cpu')
-parser.add_argument('--dims', type=int, default=2048,
-                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
-                    help=('Dimensionality of Inception features to use. '
-                          'By default, uses pool3 features'))
-parser.add_argument('path', type=str, nargs=2,
-                    help=('Paths to the generated images or '
-                          'to .npz statistic files'))
+IMAGE_EXTENSIONS = (
+    'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 'tif', 'tiff', 'webp'
+)
 
-IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
-                    'tif', 'tiff', 'webp'}
-
 
-class ImagePathDataset(torch.utils.data.Dataset):
-    def __init__(self, files, transforms=None):
-        self.files = files
-        self.transforms = transforms
+class ImagePathDataset(data.Dataset):
+    def __init__(self, root, transform=None, ext=IMAGE_EXTENSIONS):
+        super().__init__()
+        # Sort for a deterministic file order.
+        self.files = sorted(f for f in Path(root).iterdir()
+                            if f.is_file() and f.suffix.lower()[1:] in ext)
+        self.transform = transform
 
     def __len__(self):
         return len(self.files)
 
-    def __getitem__(self, i):
-        path = self.files[i]
-        img = Image.open(path).convert('RGB')
-        if self.transforms is not None:
-            img = self.transforms(img)
+    def __getitem__(self, idx):
+        filename = self.files[idx]
+        img = Image.open(filename).convert('RGB')
+        if self.transform:
+            img = self.transform(img)
         return img
 
 
-def get_activations(files, model, batch_size=50, dims=2048, device='cpu'):
-    """Calculates the activations of the pool_3 layer for all images.
+class FrechetInceptionDistance:
+    def __init__(self,
+                 model,
+                 dims=2048,
+                 batch_size=32,
+                 num_workers=None,
+                 progressbar=False):
+        self.model = model
+        self.dims = dims
+        self.batch_size = batch_size
+        self.num_workers = cpu_count() if num_workers is None else num_workers
+        self.progressbar = progressbar
 
-    Params:
-    -- files       : List of image files paths
-    -- model       : Instance of inception model
-    -- batch_size  : Batch size of images for the model to process at once.
-                     Make sure that the number of samples is a multiple of
-                     the batch size, otherwise some samples are ignored. This
-                     behavior is retained to match the original FID score
-                     implementation.
-    -- dims        : Dimensionality of features returned by Inception
-    -- device      : Device to run calculations
+    @staticmethod
+    def get_inception_model(dims=2048):
+        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+        return InceptionV3([block_idx])
+
+    def _get_device(self):
+        return next(self.model.parameters()).device
+
+    def get_activations(self, batches):
+        """
+        Calculates the activations of the pool_3 layer for all images.
+
+        Args:
+            batches: Iterable yielding image batches as PyTorch tensors.
+
+        Returns:
+            A numpy array of dimension (num images, dims) containing
+            feature activations for the given images.
+        """
+
+        self.model.eval()
+
+        device = self._get_device()
+
+        activations = []
+
+        if self.progressbar:
+            batches = tqdm(batches)
+
+        for batch in batches:
+            batch = batch.to(device)
+
+            # Inference only; gradients are not needed.
+            with torch.no_grad():
+                pred = self.model(batch)[0]
+
+            # If model output is not scalar, apply global spatial
+            # average pooling. This happens if you choose a
+            # dimensionality not equal 2048.
+            if pred.size(2) != 1 or pred.size(3) != 1:
+                pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+            activations.append(
+                pred.cpu().numpy().reshape(pred.size(0), -1)
+            )
+
+        return np.concatenate(activations)
+
+    @staticmethod
+    def calculate_activation_statistics(activations):
+        """
+        Calculates statistics used for FID by given feature activations.
+
+        Args:
+            activations: Numpy array of dimension (num images, dims)
+                containing feature activations.
+
+        Returns:
+            Mean and covariance matrix over the given activations.
+        """
+        mu = np.mean(activations, axis=0)
+        sigma = np.cov(activations, rowvar=False)
+        return mu, sigma
+
+    def get_activation_statistics(self, batches):
+        activations = self.get_activations(batches)
+        mu, sigma = self.calculate_activation_statistics(activations)
+        return mu, sigma
+
+    @staticmethod
+    def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+        """
+        Numpy implementation of the Frechet Distance.
+        The Frechet distance between two multivariate
+        Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is
+        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
 
-    Returns:
-    -- A numpy array of dimension (num images, dims) that contains the
-       activations of the given tensor when feeding inception with the
-       query tensor.
-    """
-    model.eval()
+        Stable version by Dougal J. Sutherland.
 
-    if batch_size > len(files):
-        print(('Warning: batch size is bigger than the data size. '
-               'Setting batch size to data size'))
-        batch_size = len(files)
-
-    dataset = ImagePathDataset(files, transforms=TF.ToTensor())
-    dataloader = torch.utils.data.DataLoader(dataset,
-                                             batch_size=batch_size,
-                                             shuffle=False,
-                                             drop_last=False,
-                                             num_workers=cpu_count())
-
-    pred_arr = np.empty((len(files), dims))
-
-    start_idx = 0
-
-    for batch in tqdm(dataloader):
-        batch = batch.to(device)
-
-        with torch.no_grad():
-            pred = model(batch)[0]
-
-        # If model output is not scalar, apply global spatial average pooling.
-        # This happens if you choose a dimensionality not equal 2048.
-        if pred.size(2) != 1 or pred.size(3) != 1:
-            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
-
-        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
-
-        pred_arr[start_idx:start_idx + pred.shape[0]] = pred
-
-        start_idx = start_idx + pred.shape[0]
-
-    return pred_arr
-
-
-def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
-    """Numpy implementation of the Frechet Distance.
-    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
-    and X_2 ~ N(mu_2, C_2) is
-            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
-
-    Stable version by Dougal J. Sutherland.
-
-    Params:
-    -- mu1   : Numpy array containing the activations of a layer of the
-               inception net (like returned by the function 'get_predictions')
-               for generated samples.
-    -- mu2   : The sample mean over activations, precalculated on an
-               representative data set.
-    -- sigma1: The covariance matrix over activations for generated samples.
-    -- sigma2: The covariance matrix over activations, precalculated on an
-               representative data set.
-
-    Returns:
-    --   : The Frechet Distance.
-    """
-
-    mu1 = np.atleast_1d(mu1)
-    mu2 = np.atleast_1d(mu2)
-
-    sigma1 = np.atleast_2d(sigma1)
-    sigma2 = np.atleast_2d(sigma2)
-
-    assert mu1.shape == mu2.shape, \
-        'Training and test mean vectors have different lengths'
-    assert sigma1.shape == sigma2.shape, \
-        'Training and test covariances have different dimensions'
-
-    diff = mu1 - mu2
-
-    # Product might be almost singular
-    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
-    if not np.isfinite(covmean).all():
-        msg = ('fid calculation produces singular product; '
-               'adding %s to diagonal of cov estimates') % eps
-        print(msg)
-        offset = np.eye(sigma1.shape[0]) * eps
-        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
-
-    # Numerical error might give slight imaginary component
-    if np.iscomplexobj(covmean):
-        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
-            m = np.max(np.abs(covmean.imag))
-            raise ValueError('Imaginary component {}'.format(m))
-        covmean = covmean.real
-
-    tr_covmean = np.trace(covmean)
-
-    return (diff.dot(diff) + np.trace(sigma1)
-            + np.trace(sigma2) - 2 * tr_covmean)
-
-
-def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
-                                    device='cpu'):
-    """Calculation of the statistics used by the FID.
-    Params:
-    -- files       : List of image files paths
-    -- model       : Instance of inception model
-    -- batch_size  : The images numpy array is split into batches with
-                     batch size batch_size. A reasonable batch size
-                     depends on the hardware.
-    -- dims        : Dimensionality of features returned by Inception
-    -- device      : Device to run calculations
-
-    Returns:
-    -- mu    : The mean over samples of the activations of the pool_3 layer of
-               the inception model.
-    -- sigma : The covariance matrix of the activations of the pool_3 layer of
-               the inception model.
-    """
-    act = get_activations(files, model, batch_size, dims, device)
-    mu = np.mean(act, axis=0)
-    sigma = np.cov(act, rowvar=False)
-    return mu, sigma
-
-
-def compute_statistics_of_path(path, model, batch_size, dims, device):
-    if path.endswith('.npz'):
-        with np.load(path) as f:
-            m, s = f['mu'][:], f['sigma'][:]
-    else:
-        path = pathlib.Path(path)
-        files = sorted([file for ext in IMAGE_EXTENSIONS
-                       for file in path.glob('*.{}'.format(ext))])
-        m, s = calculate_activation_statistics(files, model, batch_size,
-                                               dims, device)
+        Args:
+            mu1: Mean over first set of activations.
+            sigma1: Covariance matrix over first set of activations.
+            mu2: Mean over second set of activations.
+            sigma2: Covariance matrix over second set of activations.
+
+        Returns:
+            The Frechet Inception Distance.
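+
+        Example (illustrative sketch): with equal covariances the trace
+        terms cancel and the distance reduces to the squared norm of
+        the mean difference, which the unit test below relies on:
+
+            mu1, mu2 = np.zeros(3), np.ones(3)
+            sigma = np.eye(3)
+            # ||mu1 - mu2||^2 = 3.0
+            FrechetInceptionDistance.calculate_frechet_distance(
+                mu1, sigma, mu2, sigma)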
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real - return m, s + tr_covmean = np.trace(covmean) + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) -def calculate_fid_given_paths(paths, batch_size, device, dims): - """Calculates the FID of two paths""" - for p in paths: - if not os.path.exists(p): - raise RuntimeError('Invalid path: %s' % p) + def get_batches_from_image_folder(self, path): + transformations = [ + transforms.Resize((299, 299)), + transforms.ToTensor() + ] - block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + images = ImagePathDataset( + path, transform=transforms.Compose(transformations) + ) - model = InceptionV3([block_idx]).to(device) + if not len(images): + raise AssertionError(f'No images found in path {path}') - m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, - dims, device) - m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, - dims, device) - fid_value = calculate_frechet_distance(m1, s1, m2, s2) + batches = DataLoader( + images, + batch_size=self.batch_size, + shuffle=False, + drop_last=False, + num_workers=self.num_workers, + ) - return fid_value + return batches + + def get_statistics_for_path(self, path, cache=True): + path = Path(path) + if path.is_file() and path.suffix in ['.np', '.npz']: + cached = path + cache = True + else: + cached = path / f'inception_statistics_{self.dims}.npz' + + if cached.is_file() and cache: + with np.load(cached) as fp: + m, s = fp['mu'][:], fp['sigma'][:] + return m, s + + batches = self.get_batches_from_image_folder(path) + + activations = self.get_activations(batches) + m, s = self.calculate_activation_statistics(activations) + + if cache: + np.savez(cached, mu=m, sigma=s) + + return m, s + + def calculate_fid_given_paths(self, path1, path2, cache=True): + m1, s1 = self.get_statistics_for_path(path1, cache=cache) + m2, s2 = self.get_statistics_for_path(path2, cache=cache) + + fid_score = self.calculate_frechet_distance(m1, s1, m2, s2) + + return fid_score def main(): + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument('--batch-size', type=int, default=64, + help='Batch size to use') + parser.add_argument('--dims', type=int, default=2048, + choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), + help=('Dimensionality of Inception features to use. ' + 'By default, uses pool3 features')) + parser.add_argument('--device', + help='Device to use. 
+                             'Defaults to cuda if available, else cpu.')
+    parser.add_argument('path', type=str, nargs=2,
+                        help=('Paths to the generated images or '
+                              'to .npz statistic files'))
+    parser.add_argument('--cache', action='store_true',
+                        help='Reuse cached statistics from the given image '
+                             'folders and cache newly computed statistics '
+                             'there.')
+    parser.add_argument('--num-workers', type=int, default=4,
+                        help='Number of DataLoader worker processes')
     args = parser.parse_args()
 
     if args.device is None:
-        device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     else:
         device = torch.device(args.device)
 
-    fid_value = calculate_fid_given_paths(args.path,
-                                          args.batch_size,
-                                          device,
-                                          args.dims)
-    print('FID: ', fid_value)
+    model = FrechetInceptionDistance.get_inception_model(args.dims).to(device)
+    fid = FrechetInceptionDistance(
+        model, args.dims, args.batch_size, args.num_workers, progressbar=True
+    )
+    fid_score = fid.calculate_fid_given_paths(
+        args.path[0], args.path[1], cache=args.cache
+    )
+
+    print('FID:', fid_score)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_fid_score.py b/tests/test_fid_score.py
index 2ce63f3..9bc0c28 100644
--- a/tests/test_fid_score.py
+++ b/tests/test_fid_score.py
@@ -3,7 +3,7 @@
 import torch
 from PIL import Image
 
-from pytorch_fid import fid_score, inception
+from pytorch_fid import FrechetInceptionDistance, fid_score, inception
 
 
 @pytest.fixture
@@ -11,12 +11,12 @@ def device():
     return torch.device('cpu')
 
 
-def test_calculate_fid_given_statistics(mocker, tmp_path, device):
+def test_calculate_fid_given_statistics(mocker, tmp_path):
     dim = 2048
     m1, m2 = np.zeros((dim,)), np.ones((dim,))
     sigma = np.eye(dim)
 
-    def dummy_statistics(path, model, batch_size, dims, device):
+    def dummy_statistics(path, cache):
         if path.endswith('1'):
             return m1, sigma
         elif path.endswith('2'):
@@ -24,8 +24,10 @@ def dummy_statistics(path, model, batch_size, dims, device):
         else:
             raise ValueError
 
-    mocker.patch('pytorch_fid.fid_score.compute_statistics_of_path',
-                 side_effect=dummy_statistics)
+    mocker.patch(
+        'pytorch_fid.FrechetInceptionDistance.get_statistics_for_path',
+        side_effect=dummy_statistics
+    )
 
     dir_names = ['1', '2']
     paths = []
@@ -34,10 +36,8 @@ def dummy_statistics(path, model, batch_size, dims, device):
         path.mkdir()
         paths.append(str(path))
 
-    fid_value = fid_score.calculate_fid_given_paths(paths,
-                                                    batch_size=dim,
-                                                    device=device,
-                                                    dims=dim)
+    # The model is unused here: the statistics method is mocked above.
+    fid = FrechetInceptionDistance(None)
+    fid_value = fid.calculate_fid_given_paths(*paths)
 
     # Given equal covariance, FID is just the squared norm of difference
     assert fid_value == np.sum((m1 - m2)**2)
 
 
@@ -47,6 +47,11 @@ def test_compute_statistics_of_path(mocker, tmp_path, device):
     model = mocker.MagicMock(inception.InceptionV3)()
     model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)]
 
+    mocker.patch(
+        'pytorch_fid.FrechetInceptionDistance._get_device',
+        return_value=device
+    )
+
     size = (4, 4, 3)
     arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]
     images = [(arr * 255).astype(np.uint8) for arr in arrays]
@@ -56,10 +61,8 @@ def test_compute_statistics_of_path(mocker, tmp_path, device):
         paths.append(str(tmp_path / '{}.png'.format(idx)))
         Image.fromarray(image, mode='RGB').save(paths[-1])
 
-    stats = fid_score.compute_statistics_of_path(str(tmp_path), model,
-                                                 batch_size=len(images),
-                                                 dims=3,
-                                                 device=device)
+    fid = FrechetInceptionDistance(model, dims=3, batch_size=len(images))
+    stats = fid.get_statistics_for_path(str(tmp_path), cache=False)
 
     assert np.allclose(stats[0], np.ones((3,)) * 0.5,
                        atol=1e-3)
     assert np.allclose(stats[1], np.ones((3, 3)) * 0.25)
 
@@ -75,10 +78,8 @@ def test_compute_statistics_of_path_from_file(mocker, tmp_path, device):
     with path.open('wb') as f:
         np.savez(f, mu=mu, sigma=sigma)
 
-    stats = fid_score.compute_statistics_of_path(str(path), model,
-                                                 batch_size=1,
-                                                 dims=5,
-                                                 device=device)
+    fid = FrechetInceptionDistance(model, dims=5)
+    stats = fid.get_statistics_for_path(str(path))
 
     assert np.allclose(stats[0], mu)
     assert np.allclose(stats[1], sigma)
@@ -93,7 +94,7 @@ def test_image_types(tmp_path):
         paths.append(str(tmp_path / 'img.{}'.format(ext)))
         in_image.save(paths[-1])
 
-    dataset = fid_score.ImagePathDataset(paths)
+    dataset = fid_score.ImagePathDataset(tmp_path)
 
     for img in dataset:
         assert np.allclose(np.array(img), in_arr)
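
A minimal usage sketch of the refactored API (illustrative only;
'path/to/real' and 'path/to/fake' are placeholder image folders):

    import torch
    from pytorch_fid import FrechetInceptionDistance

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = FrechetInceptionDistance.get_inception_model(dims=2048).to(device)
    fid = FrechetInceptionDistance(model, dims=2048, batch_size=32)
    # cache=False skips reading and writing inception_statistics_*.npz files
    score = fid.calculate_fid_given_paths('path/to/real', 'path/to/fake',
                                          cache=False)
    print('FID:', score)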