
Commit 24588e8

Merge pull request #227 from Adversarian/main
Robust batched FID calculation
2 parents: ac933e4 + 9f77c49

2 files changed: +151 −48 lines

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 45 additions & 48 deletions
@@ -26,8 +26,7 @@
 from accelerate import Accelerator

 import numpy as np
-from pytorch_fid.inception import InceptionV3
-from pytorch_fid.fid_score import calculate_frechet_distance
+from denoising_diffusion_pytorch.fid_evaluation import FIDEvaluation

 from denoising_diffusion_pytorch.version import __version__

@@ -610,7 +609,7 @@ def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
         return model_mean, posterior_variance, posterior_log_variance, x_start

-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample(self, x, t: int, x_self_cond = None):
         b, *_, device = *x.shape, self.device
         batched_times = torch.full((b,), t, device = device, dtype = torch.long)

@@ -619,7 +618,7 @@ def p_sample(self, x, t: int, x_self_cond = None):
         pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
         return pred_img, x_start

-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample_loop(self, shape, return_all_timesteps = False):
         batch, device = shape[0], self.device

@@ -638,7 +637,7 @@ def p_sample_loop(self, shape, return_all_timesteps = False):
         ret = self.unnormalize(ret)
         return ret

-    @torch.no_grad()
+    @torch.inference_mode()
     def ddim_sample(self, shape, return_all_timesteps = False):
         batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

@@ -680,13 +679,13 @@ def ddim_sample(self, shape, return_all_timesteps = False):
         ret = self.unnormalize(ret)
         return ret

-    @torch.no_grad()
+    @torch.inference_mode()
     def sample(self, batch_size = 16, return_all_timesteps = False):
         image_size, channels = self.image_size, self.channels
         sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
         return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)

-    @torch.no_grad()
+    @torch.inference_mode()
     def interpolate(self, x1, x2, t = None, lam = 0.5):
         b, *_, device = *x1.shape, x1.device
         t = default(t, self.num_timesteps - 1)

@@ -738,7 +737,7 @@ def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):

         x_self_cond = None
         if self.self_condition and random() < 0.5:
-            with torch.no_grad():
+            with torch.inference_mode():
                 x_self_cond = self.model_predictions(x, t).pred_x_start
                 x_self_cond.detach_()

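A change repeated throughout this file is the swap of @torch.no_grad() (and with torch.no_grad():) for @torch.inference_mode(). Both disable gradient tracking during sampling; inference mode additionally skips autograd's view tracking and version-counter bookkeeping, so it is slightly faster and lighter, with the caveat that tensors created inside it cannot later be used in operations recorded by autograd. A minimal standalone sketch of the two modes, for illustration only (not part of the commit):

import torch

x = torch.randn(4, requires_grad = True)

with torch.no_grad():
    y = x * 2                # gradients are not tracked, but y is an ordinary tensor

with torch.inference_mode():
    z = x * 2                # cheaper: no version counters or view metadata are recorded
                             # z is an "inference tensor" and cannot re-enter autograd later

@torch.inference_mode()      # also usable as a decorator, as in the sampling methods above
def double(t):
    return t * 2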

@@ -829,7 +828,9 @@ def __init__(
         convert_image_to = None,
         calculate_fid = True,
         inception_block_idx = 2048,
-        max_grad_norm = 1.
+        max_grad_norm = 1.,
+        num_fid_samples = 50000,
+        save_best_and_latest_only = False
     ):
         super().__init__()

@@ -845,21 +846,15 @@ def __init__(
         self.model = diffusion_model
         self.channels = diffusion_model.channels

-        # InceptionV3 for fid-score computation
-
-        self.inception_v3 = None
-
-        if calculate_fid:
-            assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
-            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
-            self.inception_v3 = InceptionV3([block_idx])
-            self.inception_v3.to(self.device)
-
         # sampling and training hyperparameters

         assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
         self.num_samples = num_samples
         self.save_and_sample_every = save_and_sample_every
+        if save_best_and_latest_only:
+            assert calculate_fid, "`calculate_fid` must be True to provide a means for model evaluation for `save_best_and_latest_only`."
+            self.best_fid = 1e10 # infinite
+        self.save_best_and_latest_only = save_best_and_latest_only

         self.batch_size = train_batch_size
         self.gradient_accumulate_every = gradient_accumulate_every

@@ -898,6 +893,27 @@ def __init__(

         self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

+        # FID-score computation
+
+        if calculate_fid:
+            self.calculate_fid = True
+            if not self.model.is_ddim_sampling:
+                self.accelerator.print(
+                    "WARNING: Robust FID computation requires a lot of generated samples and can therefore be very time consuming."\
+                    "Consider using DDIM sampling to save time."
+                )
+            self.fid_scorer = FIDEvaluation(
+                batch_size=self.batch_size,
+                dl=self.dl,
+                sampler=self.ema.ema_model,
+                channels=self.channels,
+                accelerator=self.accelerator,
+                stats_dir=results_folder,
+                device=self.device,
+                num_fid_samples=num_fid_samples,
+                inception_block_idx=inception_block_idx
+            )
+
     @property
     def device(self):
         return self.accelerator.device

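The hunks above add two Trainer options: num_fid_samples, the sample budget used for the FID estimate, and save_best_and_latest_only, which requires calculate_fid = True and switches checkpointing from numbered milestones to a best/latest pair. A hedged usage sketch; apart from the two new keyword arguments, the model and trainer settings below are illustrative README-style defaults, not prescribed by this commit:

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,
    sampling_timesteps = 250    # fewer sampling steps than timesteps enables DDIM sampling,
                                # which the warning above recommends for faster FID evaluation
)

trainer = Trainer(
    diffusion,
    '/path/to/images',
    train_batch_size = 32,
    calculate_fid = True,               # must stay True when save_best_and_latest_only is set
    num_fid_samples = 50000,            # new in this commit: sample budget for the FID estimate
    save_best_and_latest_only = True    # new in this commit: keep only the best-FID and latest checkpoints
)

trainer.train()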
@@ -937,31 +953,6 @@ def load(self, milestone):
         if exists(self.accelerator.scaler) and exists(data['scaler']):
             self.accelerator.scaler.load_state_dict(data['scaler'])

-    @torch.no_grad()
-    def calculate_activation_statistics(self, samples):
-        assert exists(self.inception_v3)
-
-        features = self.inception_v3(samples)[0]
-        features = rearrange(features, '... 1 1 -> ...').cpu().numpy()
-
-        mu = np.mean(features, axis = 0)
-        sigma = np.cov(features, rowvar = False)
-        return mu, sigma
-
-    def fid_score(self, real_samples, fake_samples):
-
-        if self.channels == 1:
-            real_samples, fake_samples = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (real_samples, fake_samples))
-
-        min_batch = min(real_samples.shape[0], fake_samples.shape[0])
-        real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))
-
-        m1, s1 = self.calculate_activation_statistics(real_samples)
-        m2, s2 = self.calculate_activation_statistics(fake_samples)
-
-        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
-        return fid_value
-
     def train(self):
         accelerator = self.accelerator
         device = accelerator.device
@@ -999,21 +990,27 @@ def train(self):
                     if self.step != 0 and self.step % self.save_and_sample_every == 0:
                         self.ema.ema_model.eval()

-                        with torch.no_grad():
+                        with torch.inference_mode():
                             milestone = self.step // self.save_and_sample_every
                             batches = num_to_groups(self.num_samples, self.batch_size)
                             all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                         all_images = torch.cat(all_images_list, dim = 0)

                         utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
-                        self.save(milestone)

                         # whether to calculate fid

-                        if exists(self.inception_v3):
-                            fid_score = self.fid_score(real_samples = data, fake_samples = all_images)
+                        if self.calculate_fid:
+                            fid_score = self.fid_scorer.fid_score()
                             accelerator.print(f'fid_score: {fid_score}')
+                        if self.save_best_and_latest_only:
+                            if self.best_fid > fid_score:
+                                self.best_fid = fid_score
+                                self.save("best")
+                            self.save("latest")
+                        else:
+                            self.save(milestone)

                 pbar.update(1)

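Both the sampling block above and the new evaluator split a sample budget into batch-sized chunks with num_to_groups (re-defined in the new file below). A quick worked example of its behaviour; the call values here are illustrative:

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

print(num_to_groups(50000, 32))    # 1562 full batches of 32 plus a final batch of 16
print(num_to_groups(64, 32))       # [32, 32] -- no remainder, so no extra batch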

denoising_diffusion_pytorch/fid_evaluation.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+import math
+import os
+
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from pytorch_fid.fid_score import calculate_frechet_distance
+from pytorch_fid.inception import InceptionV3
+from torch.nn.functional import adaptive_avg_pool2d
+from tqdm.auto import tqdm
+
+
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
+
+class FIDEvaluation:
+    def __init__(
+        self,
+        batch_size,
+        dl,
+        sampler,
+        channels=3,
+        accelerator=None,
+        stats_dir="./results",
+        device="cuda",
+        num_fid_samples=50000,
+        inception_block_idx=2048,
+    ):
+        self.batch_size = batch_size
+        self.n_samples = num_fid_samples
+        self.device = device
+        self.channels = channels
+        self.dl = dl
+        self.sampler = sampler
+        self.stats_dir = stats_dir
+        self.print_fn = print if accelerator is None else accelerator.print
+        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
+        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
+        self.inception_v3 = InceptionV3([block_idx]).to(device)
+        self.dataset_stats_loaded = False
+
+    def calculate_inception_features(self, samples):
+        if self.channels == 1:
+            samples = repeat(samples, "b 1 ... -> b c ...", c=3)
+        features = self.inception_v3(samples)[0]
+        if features.size(2) != 1 or features.size(3) != 1:
+            features = adaptive_avg_pool2d(features, output_size=(1, 1))
+        features = rearrange(features, "... 1 1 -> ...")
+        return features
+
+    def load_or_precalc_dataset_stats(self):
+        path = os.path.join(self.stats_dir, "dataset_stats")
+        try:
+            ckpt = np.load(path + ".npz")
+            self.m2, self.s2 = ckpt["m2"], ckpt["s2"]
+            self.print_fn("Dataset stats loaded from disk.")
+            ckpt.close()
+        except OSError:
+            num_batches = int(math.ceil(self.n_samples / self.batch_size))
+            stacked_real_features = []
+            self.print_fn(
+                f"Stacking Inception features for {self.n_samples} samples from the real dataset."
+            )
+            for _ in tqdm(range(num_batches)):
+                try:
+                    real_samples = next(self.dl)
+                except StopIteration:
+                    break
+                real_samples = real_samples.to(self.device)
+                real_features = self.calculate_inception_features(real_samples)
+                stacked_real_features.append(real_features)
+            stacked_real_features = (
+                torch.cat(stacked_real_features, dim=0).cpu().numpy()
+            )
+            m2 = np.mean(stacked_real_features, axis=0)
+            s2 = np.cov(stacked_real_features, rowvar=False)
+            np.savez_compressed(path, m2=m2, s2=s2)
+            self.print_fn(f"Dataset stats cached to {path}.npz for future use.")
+            self.m2, self.s2 = m2, s2
+        self.dataset_stats_loaded = True
+
+    @torch.inference_mode()
+    def fid_score(self):
+        if not self.dataset_stats_loaded:
+            self.load_or_precalc_dataset_stats()
+        self.sampler.eval()
+        batches = num_to_groups(self.n_samples, self.batch_size)
+        stacked_fake_features = []
+        self.print_fn(
+            f"Stacking Inception features for {self.n_samples} generated samples."
+        )
+        for batch in tqdm(batches):
+            fake_samples = self.sampler.sample(batch_size=batch)
+            fake_features = self.calculate_inception_features(fake_samples)
+            stacked_fake_features.append(fake_features)
+        stacked_fake_features = torch.cat(stacked_fake_features, dim=0).cpu().numpy()
+        m1 = np.mean(stacked_fake_features, axis=0)
+        s1 = np.cov(stacked_fake_features, rowvar=False)
+
+        return calculate_frechet_distance(m1, s1, self.m2, self.s2)

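The FIDEvaluation class can also be used on its own, outside the Trainer. A minimal sketch; dl is assumed to be a cycling dataloader yielding real image tensors and diffusion any sampler object exposing .eval() and .sample(batch_size=...) — both names are illustrative placeholders, not part of the commit. Under the hood, calculate_frechet_distance from pytorch_fid compares the mean and covariance of InceptionV3 features of generated versus real images.

from denoising_diffusion_pytorch.fid_evaluation import FIDEvaluation

fid_eval = FIDEvaluation(
    batch_size=32,
    dl=dl,                      # assumed: cycled dataloader of real images, as prepared by Trainer
    sampler=diffusion,          # assumed: anything with .eval() and .sample(batch_size=n)
    channels=3,
    stats_dir="./results",      # dataset_stats.npz is cached here after the first call
    device="cuda",
    num_fid_samples=50000,
    inception_block_idx=2048,
)

score = fid_eval.fid_score()    # the first call also computes and caches the real-data statistics
print(f"FID: {score}")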
