@@ -26,8 +26,7 @@
 from accelerate import Accelerator

 import numpy as np
-from pytorch_fid.inception import InceptionV3
-from pytorch_fid.fid_score import calculate_frechet_distance
+from denoising_diffusion_pytorch.fid_evaluation import FIDEvaluation

 from denoising_diffusion_pytorch.version import __version__

@@ -610,7 +609,7 @@ def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
         return model_mean, posterior_variance, posterior_log_variance, x_start

-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample(self, x, t: int, x_self_cond = None):
         b, *_, device = *x.shape, self.device
         batched_times = torch.full((b,), t, device = device, dtype = torch.long)
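The no_grad-to-inference_mode swap here (and in the hunks below) trades plain autograd suspension for PyTorch's stricter inference mode, which also skips view tracking and version-counter updates, making sampling slightly faster and leaner on memory. A minimal standalone sketch of the semantics (hypothetical function, not code from this repo):

    import torch

    @torch.inference_mode()
    def sample_step(x):
        # autograd is fully disabled here; outputs are flagged as inference tensors
        return x * 2

    y = sample_step(torch.randn(4))
    print(y.requires_grad)   # False
    print(y.is_inference())  # True: y cannot take part in autograd later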
@@ -619,7 +618,7 @@ def p_sample(self, x, t: int, x_self_cond = None):
         pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
         return pred_img, x_start

-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample_loop(self, shape, return_all_timesteps = False):
         batch, device = shape[0], self.device

@@ -638,7 +637,7 @@ def p_sample_loop(self, shape, return_all_timesteps = False):
         ret = self.unnormalize(ret)
         return ret

-    @torch.no_grad()
+    @torch.inference_mode()
     def ddim_sample(self, shape, return_all_timesteps = False):
         batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

@@ -680,13 +679,13 @@ def ddim_sample(self, shape, return_all_timesteps = False):
         ret = self.unnormalize(ret)
         return ret

-    @torch.no_grad()
+    @torch.inference_mode()
     def sample(self, batch_size = 16, return_all_timesteps = False):
         image_size, channels = self.image_size, self.channels
         sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
         return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)

-    @torch.no_grad()
+    @torch.inference_mode()
     def interpolate(self, x1, x2, t = None, lam = 0.5):
         b, *_, device = *x1.shape, x1.device
         t = default(t, self.num_timesteps - 1)
@@ -738,7 +737,7 @@ def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):

         x_self_cond = None
         if self.self_condition and random() < 0.5:
-            with torch.no_grad():
+            with torch.inference_mode():
                 x_self_cond = self.model_predictions(x, t).pred_x_start
                 x_self_cond.detach_()

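One caveat for this last hunk: x_self_cond is now created inside inference_mode, and inference tensors cannot be saved for the backward pass, so reusing it in the training forward pass can raise a runtime error; cloning outside the context is the documented workaround. A self-contained illustration of the PyTorch behavior (hypothetical names, not this repo's code):

    import torch

    w = torch.randn(3, requires_grad = True)

    with torch.inference_mode():
        cond = (w * 2).detach()

    # w * cond must save cond for backward, which inference tensors forbid
    # ("Inference tensors cannot be saved for backward"); cloning re-materializes
    # cond as a normal tensor first.
    cond = cond.clone()
    loss = (w * cond).sum()
    loss.backward()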
@@ -829,7 +828,9 @@ def __init__(
         convert_image_to = None,
         calculate_fid = True,
         inception_block_idx = 2048,
-        max_grad_norm = 1.
+        max_grad_norm = 1.,
+        num_fid_samples = 50000,
+        save_best_and_latest_only = False
     ):
         super().__init__()
@@ -845,21 +846,15 @@ def __init__(
         self.model = diffusion_model
         self.channels = diffusion_model.channels

-        # InceptionV3 for fid-score computation
-
-        self.inception_v3 = None
-
-        if calculate_fid:
-            assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
-            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
-            self.inception_v3 = InceptionV3([block_idx])
-            self.inception_v3.to(self.device)
-
         # sampling and training hyperparameters

         assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
         self.num_samples = num_samples
         self.save_and_sample_every = save_and_sample_every
+        if save_best_and_latest_only:
+            assert calculate_fid, "`calculate_fid` must be True to provide a means for model evaluation for `save_best_and_latest_only`."
+            self.best_fid = 1e10 # effectively infinite sentinel
+        self.save_best_and_latest_only = save_best_and_latest_only

         self.batch_size = train_batch_size
         self.gradient_accumulate_every = gradient_accumulate_every
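For context, the two new constructor arguments slot into the usual Trainer setup from this repo's README roughly as follows (a sketch; the path and hyperparameters are placeholders):

    from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

    model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))

    diffusion = GaussianDiffusion(
        model,
        image_size = 128,
        timesteps = 1000,
        sampling_timesteps = 250    # DDIM sampling keeps the FID passes fast
    )

    trainer = Trainer(
        diffusion,
        'path/to/images',
        train_batch_size = 32,
        calculate_fid = True,               # must stay True when saving best/latest only
        num_fid_samples = 50000,            # new: samples drawn per FID evaluation
        save_best_and_latest_only = True    # new: keep only 'best' and 'latest' checkpoints
    )

    trainer.train()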
@@ -898,6 +893,27 @@ def __init__(

         self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

+        # FID-score computation
+
+        self.calculate_fid = calculate_fid # always set, so the check in train() never hits an undefined attribute
+        if calculate_fid:
+            if not self.model.is_ddim_sampling:
+                self.accelerator.print(
+                    "WARNING: Robust FID computation requires a lot of generated samples "
+                    "and can therefore be very time consuming. Consider using DDIM sampling to save time."
+                )
+            self.fid_scorer = FIDEvaluation(
+                batch_size = self.batch_size,
+                dl = self.dl,
+                sampler = self.ema.ema_model,
+                channels = self.channels,
+                accelerator = self.accelerator,
+                stats_dir = results_folder,
+                device = self.device,
+                num_fid_samples = num_fid_samples,
+                inception_block_idx = inception_block_idx
+            )
+
     @property
     def device(self):
         return self.accelerator.device
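The FIDEvaluation helper bundles what the inline methods removed below used to do: collect InceptionV3 activations for real and generated images, fit a Gaussian to each feature set, and take the Fréchet distance between the two. In terms of the pytorch_fid helpers the old code imported, the core computation is roughly this (hypothetical wrapper function):

    import numpy as np
    from pytorch_fid.fid_score import calculate_frechet_distance

    def fid_from_features(real_feats, fake_feats):
        # each input: (num_samples, feature_dim) array of InceptionV3 activations
        m1, s1 = np.mean(real_feats, axis = 0), np.cov(real_feats, rowvar = False)
        m2, s2 = np.mean(fake_feats, axis = 0), np.cov(fake_feats, rowvar = False)
        return calculate_frechet_distance(m1, s1, m2, s2)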
@@ -937,31 +953,6 @@ def load(self, milestone):
         if exists(self.accelerator.scaler) and exists(data['scaler']):
             self.accelerator.scaler.load_state_dict(data['scaler'])

-    @torch.no_grad()
-    def calculate_activation_statistics(self, samples):
-        assert exists(self.inception_v3)
-
-        features = self.inception_v3(samples)[0]
-        features = rearrange(features, '... 1 1 -> ...').cpu().numpy()
-
-        mu = np.mean(features, axis = 0)
-        sigma = np.cov(features, rowvar = False)
-        return mu, sigma
-
-    def fid_score(self, real_samples, fake_samples):
-
-        if self.channels == 1:
-            real_samples, fake_samples = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (real_samples, fake_samples))
-
-        min_batch = min(real_samples.shape[0], fake_samples.shape[0])
-        real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))
-
-        m1, s1 = self.calculate_activation_statistics(real_samples)
-        m2, s2 = self.calculate_activation_statistics(fake_samples)
-
-        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
-        return fid_value
-
     def train(self):
         accelerator = self.accelerator
         device = accelerator.device
@@ -999,21 +990,27 @@ def train(self):
                    if self.step != 0 and self.step % self.save_and_sample_every == 0:
                        self.ema.ema_model.eval()

-                       with torch.no_grad():
+                       with torch.inference_mode():
                            milestone = self.step // self.save_and_sample_every
                            batches = num_to_groups(self.num_samples, self.batch_size)
                            all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size = n), batches))

                        all_images = torch.cat(all_images_list, dim = 0)

                        utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
-                       self.save(milestone)

                        # whether to calculate fid

-                       if exists(self.inception_v3):
-                           fid_score = self.fid_score(real_samples = data, fake_samples = all_images)
+                       if self.calculate_fid:
+                           fid_score = self.fid_scorer.fid_score()
                            accelerator.print(f'fid_score: {fid_score}')
+                       if self.save_best_and_latest_only:
+                           if self.best_fid > fid_score:
+                               self.best_fid = fid_score
+                               self.save("best")
+                           self.save("latest")
+                       else:
+                           self.save(milestone)

                pbar.update(1)