diff --git a/models/base_model.py b/models/base_model.py index 6de961b51a2..c24ba207973 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -95,6 +95,13 @@ def eval(self): net = getattr(self, 'net' + name) net.eval() + def train(self): + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net.train() + + def test(self): """Forward function used in test time. diff --git a/train.py b/train.py index 2852652df82..f3ab4e267f2 100644 --- a/train.py +++ b/train.py @@ -19,15 +19,35 @@ See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md """ import time +import os from options.train_options import TrainOptions +from options.test_options import TestOptions from data import create_dataset from models import create_model -from util.visualizer import Visualizer +from util import html +from util.visualizer import Visualizer, save_images +from pytorch_fid.fid_score import calculate_fid_given_paths if __name__ == '__main__': - opt = TrainOptions().parse() # get training options + opt = TrainOptions().parse() # get training options + val_opts = TestOptions().parse() + val_opts.phase = 'val' + val_opts.num_threads = 0 # test code only supports num_threads = 0 + val_opts.batch_size = 1 # test code only supports batch_size = 1 + val_opts.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. + val_opts.no_flip = True # no flip; comment this line if results on flipped images are needed. + val_opts.display_id = -1 + dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options - dataset_size = len(dataset) # get the number of images in the dataset. + val_dataset = create_dataset(val_opts) # create a dataset given opt.dataset_mode and other options + web_dir = os.path.join(val_opts.results_dir, val_opts.name, + '{}_{}'.format(val_opts.phase, val_opts.epoch)) # define the website directory + if opt.load_iter > 0: # load_iter is 0 by default + web_dir = '{:s}_iter{:d}'.format(web_dir, opt.load_iter) + print('creating web directory', web_dir) + webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) + + dataset_size = len(dataset) # get the number of images in the dataset. print('The number of training images = %d' % dataset_size) model = create_model(opt) # create a model given opt.model and other options @@ -74,4 +94,31 @@ model.save_networks('latest') model.save_networks(epoch) - print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) + if epoch % opt.val_metric_freq == 0: + print('Evaluating FID for validation set at epoch %d, iters %d, at dataset %s' % ( + epoch, total_iters, opt.name)) + model.eval() + for i, data in enumerate(val_dataset): + model.set_input(data) # unpack data from data loader + model.test() # run inference + + visuals = model.get_current_visuals() # get image results + if opt.direction == 'BtoA': + visuals = {'fake_B': visuals['fake_B']} + else: + visuals = {'fake_A': visuals['fake_A']} + + img_path = model.get_image_paths() # get image paths + if i % 5 == 0: # save images to an HTML file + print('processing (%04d)-th image... %s' % (i, img_path)) + save_images(webpage, visuals, img_path, aspect_ratio=val_opts.aspect_ratio, + width=val_opts.display_winsize) + fid_value = calculate_fid_given_paths( + paths=('./results/{d}/val_latest/images/'.format(d=opt.name), '{d}/val'.format(d=opt.dataroot)), + batch_size=64, cuda=True, dims=2048) + visualizer.print_current_fid(epoch, fid_value) + visualizer.plot_current_fid(epoch, fid_value) + model.train() + + print('End of epoch %d / %d \t Time Taken: %d sec' % ( + epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) diff --git a/util/visualizer.py b/util/visualizer.py index 9736b5c3049..3aea9b26125 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -80,9 +80,14 @@ def __init__(self, opt): util.mkdirs([self.web_dir, self.img_dir]) # create a logging file to store training losses self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + self.fid_log_name = os.path.join(opt.checkpoints_dir, opt.name, 'fid_log.txt') + with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) + with open(self.fid_log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Validation FID (%s) ================\n' % now) def reset(self): """Reset the self.saved status""" @@ -176,6 +181,29 @@ def display_current_results(self, visuals, epoch, save_result): webpage.add_images(ims, txts, links, width=self.win_size) webpage.save() + def plot_current_fid(self, epoch, fid): + """display the current fid on visdom display + + Parameters: + epoch (int) -- current epoch + fid (float) -- validation fid + """ + if not hasattr(self, 'fid_plot_data'): + self.fid_plot_data = {'X': [], 'Y': []} + self.fid_plot_data['X'].append(epoch) + self.fid_plot_data['Y'].append(fid) + try: + self.vis.line( + X=np.array(self.fid_plot_data['X']), + Y=np.array(self.fid_plot_data['Y']), + opts={ + 'title': self.name + ' fid over time', + 'xlabel': 'epoch', + 'ylabel': 'fid'}, + win=self.display_id + 4) + except VisdomExceptionBase: + self.create_visdom_connections() + def plot_current_losses(self, epoch, counter_ratio, losses): """display the current losses on visdom display: dictionary of error labels and values @@ -219,3 +247,16 @@ def print_current_losses(self, epoch, iters, losses, t_comp, t_data): print(message) # print the message with open(self.log_name, "a") as log_file: log_file.write('%s\n' % message) # save the message + + def print_current_fid(self, epoch, fid): + """print current fid on console; also save the fid to the disk + + Parameters: + epoch (int) -- current epoch + fid (float) - fid metric + """ + message = '(epoch: %d, fid: %.3f) ' % (epoch, fid) + + print(message) # print the message + with open(self.fid_log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message