From 6168e1a82a0a7ae4279e4e27570832c99feb67f3 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Thu, 9 Jul 2020 14:11:49 +0300 Subject: [PATCH 01/19] Add NVIDIA apex support and checkpointing memory optimization (https://pytorch.org/docs/stable/checkpoint.html) --- models/cycle_gan_model.py | 40 ++++++++++++++++++++++++++++++++------- models/pix2pix_model.py | 30 ++++++++++++++++++++++++++--- options/base_options.py | 7 ++++++- 3 files changed, 66 insertions(+), 11 deletions(-) diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 15bb72d8ddc..f536503c8f8 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -3,7 +3,12 @@ from util.image_pool import ImagePool from .base_model import BaseModel from . import networks +from torch.utils.checkpoint import checkpoint +try: + from apex import amp +except ImportError: + print("Please install NVIDIA Apex for safe mixed precision if you want to use non default --opt_level") class CycleGANModel(BaseModel): """ @@ -96,6 +101,10 @@ def __init__(self, opt): self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) + if opt.apex: + [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D] = amp.initialize( + [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=3) + def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. @@ -112,11 +121,17 @@ def set_input(self, input): def forward(self): """Run forward pass; called by both functions and .""" self.fake_B = self.netG_A(self.real_A) # G_A(A) - self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) + if not self.opt.checkpointing: + self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) + else: + self.rec_A = checkpoint(self.netG_B, self.fake_B) self.fake_A = self.netG_B(self.real_B) # G_B(B) - self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) + if not self.opt.checkpointing: + self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) + else: + self.rec_B = checkpoint(self.netG_A, self.fake_A) - def backward_D_basic(self, netD, real, fake): + def backward_D_basic(self, netD, real, fake, loss_id): """Calculate GAN loss for the discriminator Parameters: @@ -135,18 +150,23 @@ def backward_D_basic(self, netD, real, fake): loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 - loss_D.backward() + if self.opt.apex: + with amp.scale_loss(loss_D, self.optimizer_D, loss_id=loss_id) as loss_D_scaled: + loss_D_scaled.backward() + else: + loss_D.backward() + return loss_D def backward_D_A(self): """Calculate GAN loss for discriminator D_A""" fake_B = self.fake_B_pool.query(self.fake_B) - self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) + self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, loss_id=0) def backward_D_B(self): """Calculate GAN loss for discriminator D_B""" fake_A = self.fake_A_pool.query(self.fake_A) - self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A, loss_id=1) def backward_G(self): """Calculate the loss for generators G_A and G_B""" @@ -175,7 +195,13 @@ def backward_G(self): self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss and calculate gradients self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + 
self.loss_idt_A + self.loss_idt_B - self.loss_G.backward() + + if self.opt.apex: + with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=2) as loss_G_scaled: + loss_G_scaled.backward() + else: + self.loss_G.backward() + def optimize_parameters(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 939eb887ee3..55692479918 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -1,7 +1,14 @@ import torch +from torch.utils.checkpoint import checkpoint + from .base_model import BaseModel from . import networks +try: + from apex import amp +except ImportError: + print("Please install NVIDIA Apex for safe mixed precision if you want to use non default --opt_level") + class Pix2PixModel(BaseModel): """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. @@ -70,6 +77,10 @@ def __init__(self, opt): self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) + if opt.apex: + [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = amp.initialize( + [self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=2) + def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. @@ -85,7 +96,10 @@ def set_input(self, input): def forward(self): """Run forward pass; called by both functions and .""" - self.fake_B = self.netG(self.real_A) # G(A) + if not self.opt.checkpointing: + self.fake_B = self.netG(self.real_A) # G(A) + else: + self.fake_B = checkpoint(self.netG, self.real_A) def backward_D(self): """Calculate GAN loss for the discriminator""" @@ -99,7 +113,12 @@ def backward_D(self): self.loss_D_real = self.criterionGAN(pred_real, True) # combine loss and calculate gradients self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 - self.loss_D.backward() + + if self.opt.apex: + with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled: + loss_D_scaled.backward() + else: + self.loss_D.backward() def backward_G(self): """Calculate GAN and L1 loss for the generator""" @@ -111,7 +130,12 @@ def backward_G(self): self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 # combine loss and calculate gradients self.loss_G = self.loss_G_GAN + self.loss_G_L1 - self.loss_G.backward() + + if self.opt.apex: + with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=1) as loss_G_scaled: + loss_G_scaled.backward() + else: + self.loss_G.backward() def optimize_parameters(self): self.forward() # compute fake images: G(A) diff --git a/options/base_options.py b/options/base_options.py index afb5d0852d1..dbed281be55 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -54,6 +54,11 @@ def initialize(self, parser): parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + parser.add_argument('--checkpointing', default=False, type=bool, + help='if true, it applies gradient checkpointing, saves memory but it makes the training slower') + parser.add_argument('--opt_level', default='O0', help='amp opt_level, default="O0" equals fp32 training') + self.initialized = True return parser @@ -114,7 +119,7 @@ def parse(self): """Parse our options, create checkpoints directory suffix, and set up gpu device.""" opt = self.gather_options() opt.isTrain = self.isTrain # train or test - + opt.apex = opt.opt_level != "O0" # process opt.suffix if opt.suffix: suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' From 653b5781080f09c6824b90e661bf4c99485b7331 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Thu, 9 Jul 2020 15:15:27 +0300 Subject: [PATCH 02/19] Add NVIDIA apex support and checkpointing memory optimization (https://pytorch.org/docs/stable/checkpoint.html) Fix data_parallel order --- models/base_model.py | 8 ++++++++ models/cycle_gan_model.py | 3 +++ models/networks.py | 1 - models/pix2pix_model.py | 2 ++ models/template_model.py | 2 ++ options/base_options.py | 2 +- 6 files changed, 16 insertions(+), 2 deletions(-) diff --git a/models/base_model.py b/models/base_model.py index 9cfb761897e..ba7ce376fb9 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -227,3 +227,11 @@ def set_requires_grad(self, nets, requires_grad=False): if net is not None: for param in net.parameters(): param.requires_grad = requires_grad + + def make_data_parallel(self): + """Make models data parallel""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net = torch.nn.DataParallel(net, self.gpu_ids) # multi-GPUs + setattr(self, 'net' + name, net) diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index f536503c8f8..69b21059adb 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -105,6 +105,9 @@ def __init__(self, opt): [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D] = amp.initialize( [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=3) + # need to be wrapped after amp.initialize + self.make_data_parallel() + def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. 
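The make_data_parallel hook introduced above exists because Apex patches each model's forward pass for mixed precision, so amp.initialize has to see the bare modules; wrapping in torch.nn.DataParallel first would hand Apex the wrapper instead of the underlying network. A minimal sketch of the intended initialization order — build_generator and build_discriminator are illustrative stand-ins for networks.define_G/define_D, and Apex is assumed to be installed:

import torch
from apex import amp

net_g = build_generator().to('cuda:0')   # 1. move parameters to the primary GPU first
net_d = build_discriminator().to('cuda:0')
opt_g = torch.optim.Adam(net_g.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(net_d.parameters(), lr=2e-4, betas=(0.5, 0.999))

# 2. amp rewrites the models and optimizers for the chosen opt_level;
#    num_losses reserves one loss scaler per loss_id used later with amp.scale_loss
[net_g, net_d], [opt_g, opt_d] = amp.initialize(
    [net_g, net_d], [opt_g, opt_d], opt_level='O1', num_losses=2)

# 3. only now replicate across GPUs, so every replica carries the patched forward
net_g = torch.nn.DataParallel(net_g, device_ids=[0, 1])
net_d = torch.nn.DataParallel(net_d, device_ids=[0, 1])

With the options these patches add, both features are opt-in from the command line, e.g. python train.py --dataroot ./datasets/maps --model cycle_gan --opt_level O1 --checkpointing; the default opt_level O0 keeps pure fp32 training and bypasses Apex entirely (parse() sets opt.apex = opt.opt_level != "O0").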
diff --git a/models/networks.py b/models/networks.py index b3a10c99c20..1bd134e5186 100644 --- a/models/networks.py +++ b/models/networks.py @@ -111,7 +111,6 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): if len(gpu_ids) > 0: assert(torch.cuda.is_available()) net.to(gpu_ids[0]) - net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs init_weights(net, init_type, init_gain=init_gain) return net diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 55692479918..f0fdf9c31e1 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -81,6 +81,8 @@ def __init__(self, opt): [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = amp.initialize( [self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=2) + self.make_data_parallel() + def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. diff --git a/models/template_model.py b/models/template_model.py index 68cdaf6a9a2..45d3659ca4c 100644 --- a/models/template_model.py +++ b/models/template_model.py @@ -67,6 +67,8 @@ def __init__(self, opt): self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [self.optimizer] + # need to be wrapped after amp.initialize + self.make_data_parallel() # Our program will automatically call to define schedulers, load networks, and print networks def set_input(self, input): diff --git a/options/base_options.py b/options/base_options.py index dbed281be55..7f0b91ab4b2 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -55,7 +55,7 @@ def initialize(self, parser): parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') - parser.add_argument('--checkpointing', default=False, type=bool, + parser.add_argument('--checkpointing', action='store_true', help='if true, it applies gradient checkpointing, saves memory but it makes the training slower') parser.add_argument('--opt_level', default='O0', help='amp opt_level, default="O0" equals fp32 training') From 48ef29feaca5a181f0bcf0b10b2eb49f58057fd3 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Fri, 10 Jul 2020 18:09:28 +0300 Subject: [PATCH 03/19] Add NVIDIA apex support and checkpointing memory optimization (https://pytorch.org/docs/stable/checkpoint.html) Disable checkpointing for pix2pix --- models/cycle_gan_model.py | 4 ++-- models/pix2pix_model.py | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 69b21059adb..8db18f2f3ff 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -124,12 +124,12 @@ def set_input(self, input): def forward(self): """Run forward pass; called by both functions and .""" self.fake_B = self.netG_A(self.real_A) # G_A(A) - if not self.opt.checkpointing: + if not self.opt.checkpointing or not self.isTrain: self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) else: self.rec_A = checkpoint(self.netG_B, self.fake_B) self.fake_A = self.netG_B(self.real_B) # G_B(B) - if not self.opt.checkpointing: + if not self.opt.checkpointing or not self.isTrain: self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) else: self.rec_B = checkpoint(self.netG_A, self.fake_A) diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 
f0fdf9c31e1..70b18a4de8c 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -98,10 +98,7 @@ def set_input(self, input): def forward(self): """Run forward pass; called by both functions and .""" - if not self.opt.checkpointing: - self.fake_B = self.netG(self.real_A) # G(A) - else: - self.fake_B = checkpoint(self.netG, self.real_A) + self.fake_B = self.netG(self.real_A) # G(A) def backward_D(self): """Calculate GAN loss for the discriminator""" From 60961635b21138dc36c23d8095bb866a74f53cb8 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Fri, 10 Jul 2020 18:18:18 +0300 Subject: [PATCH 04/19] Add NVIDIA apex support and checkpointing memory optimization (https://pytorch.org/docs/stable/checkpoint.html) Minor fix --- models/test_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/test_model.py b/models/test_model.py index fe15f40176e..a9ab50064db 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -48,6 +48,7 @@ def __init__(self, opt): # assigns the model to self.netG_[suffix] so that it can be loaded # please see setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self. + self.make_data_parallel() def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. From d2a66809f8bf38229eec8965952e375fa5bef44d Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Sat, 11 Jul 2020 00:40:28 +0300 Subject: [PATCH 05/19] Add NVIDIA apex support and checkpointing memory optimization (https://pytorch.org/docs/stable/checkpoint.html) Refactor configs --- models/cycle_gan_model.py | 4 ++-- options/base_options.py | 7 +------ options/train_options.py | 9 +++++++++ 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 8db18f2f3ff..63b980d915d 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -124,12 +124,12 @@ def set_input(self, input): def forward(self): """Run forward pass; called by both functions and .""" self.fake_B = self.netG_A(self.real_A) # G_A(A) - if not self.opt.checkpointing or not self.isTrain: + if not self.isTrain or not self.opt.checkpointing: self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) else: self.rec_A = checkpoint(self.netG_B, self.fake_B) self.fake_A = self.netG_B(self.real_B) # G_B(B) - if not self.opt.checkpointing or not self.isTrain: + if not self.isTrain or not self.opt.checkpointing: self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) else: self.rec_B = checkpoint(self.netG_A, self.fake_A) diff --git a/options/base_options.py b/options/base_options.py index 7f0b91ab4b2..afb5d0852d1 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -54,11 +54,6 @@ def initialize(self, parser): parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') - - parser.add_argument('--checkpointing', action='store_true', - help='if true, it applies gradient checkpointing, saves memory but it makes the training slower') - parser.add_argument('--opt_level', default='O0', help='amp opt_level, default="O0" equals fp32 training') - self.initialized = True return parser @@ -119,7 +114,7 @@ def parse(self): """Parse our options, create checkpoints directory suffix, and set up gpu device.""" opt = self.gather_options() opt.isTrain = self.isTrain # train or test - opt.apex = opt.opt_level != "O0" + # process opt.suffix if opt.suffix: suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' diff --git a/options/train_options.py b/options/train_options.py index c8d5d2a92a9..575cd3cbbf5 100644 --- a/options/train_options.py +++ b/options/train_options.py @@ -36,5 +36,14 @@ def initialize(self, parser): parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') + # training optimizations + parser.add_argument('--checkpointing', action='store_true', + help='if true, it applies gradient checkpointing, saves memory but it makes the training slower') + parser.add_argument('--opt_level', default='O0', help='amp opt_level, default="O0" equals fp32 training') self.isTrain = True return parser + + def parse(self): + opt = BaseOptions.parse(self) + opt.apex = opt.opt_level != "O0" + return opt From bfa902ab49ef855aca42b8e3fe88e56dcb86e461 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Thu, 16 Jul 2020 16:04:21 +0300 Subject: [PATCH 06/19] Fix CPU version --- models/base_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/models/base_model.py b/models/base_model.py index ba7ce376fb9..a387ec27173 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -230,6 +230,8 @@ def set_requires_grad(self, nets, requires_grad=False): def make_data_parallel(self): """Make models data parallel""" + if len(self.gpu_ids) == 0: + return for name in self.model_names: if isinstance(name, str): net = getattr(self, 'net' + name) From 41bc4ebaba2536226a4c2ab48b8b3a91a04cac7a Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Sat, 25 Jul 2020 02:34:45 +0300 Subject: [PATCH 07/19] Add ProGan model --- models/networks.py | 275 +++++++++++++++++++++++++++++++++++++++-- models/progan_model.py | 162 ++++++++++++++++++++++++ 2 files changed, 430 insertions(+), 7 deletions(-) create mode 100644 models/progan_model.py diff --git a/models/networks.py b/models/networks.py index 1bd134e5186..c96c49ff8f3 100644 --- a/models/networks.py +++ b/models/networks.py @@ -3,7 +3,8 @@ from torch.nn import init import functools from torch.optim import lr_scheduler - +from torch.nn import functional as F +from math import sqrt ############################################################################### # Helper Functions @@ -98,7 +99,7 @@ def init_func(m): # define the initialization function net.apply(init_func) # apply the initialization function -def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 
+def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_weights_=True): """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights Parameters: net (network) -- the network to be initialized @@ -111,11 +112,13 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): if len(gpu_ids) > 0: assert(torch.cuda.is_available()) net.to(gpu_ids[0]) - init_weights(net, init_type, init_gain=init_gain) + if init_weights_: + init_weights(net, init_type, init_gain=init_gain) return net -def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): +def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], + init_weights=True): """Create a generator Parameters: @@ -153,12 +156,15 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) elif netG == 'unet_256': net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'progan': + net = GeneratorProGan(input_code_dim=128) else: raise NotImplementedError('Generator model name [%s] is not recognized' % netG) - return init_net(net, init_type, init_gain, gpu_ids) + return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights) -def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): +def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], + init_weights=True): """Create a discriminator Parameters: @@ -197,9 +203,11 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal' net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) elif netD == 'pixel': # classify if each pixel is real or fake net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) + elif netD == 'progan': + net = DiscriminatorProGan() else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) - return init_net(net, init_type, init_gain, gpu_ids) + return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights) ############################################################################## @@ -612,3 +620,256 @@ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): def forward(self, input): """Standard forward.""" return self.net(input) + + +class EqualLR: + def __init__(self, name): + self.name = name + + def compute_weight(self, module): + weight = getattr(module, self.name + '_orig') + fan_in = weight.data.size(1) * weight.data[0][0].numel() + + return weight * sqrt(2 / fan_in) + + @staticmethod + def apply(module, name): + fn = EqualLR(name) + + weight = getattr(module, name) + del module._parameters[name] + module.register_parameter(name + '_orig', nn.Parameter(weight.data)) + module.register_forward_pre_hook(fn) + + return fn + + def __call__(self, module, input): + weight = self.compute_weight(module) + setattr(module, self.name, weight) + + +def equal_lr(module, name='weight'): + EqualLR.apply(module, name) + + return module + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + + 1e-8) + + +class EqualConv2d(nn.Module): + def __init__(self, *args, 
**kwargs): + super().__init__() + + conv = nn.Conv2d(*args, **kwargs) + conv.weight.data.normal_() + conv.bias.data.zero_() + self.conv = equal_lr(conv) + + def forward(self, input): + return self.conv(input) + + +class EqualConvTranspose2d(nn.Module): + ### additional module for OOGAN usage + def __init__(self, *args, **kwargs): + super().__init__() + + conv = nn.ConvTranspose2d(*args, **kwargs) + conv.weight.data.normal_() + conv.bias.data.zero_() + self.conv = equal_lr(conv) + + def forward(self, input): + return self.conv(input) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + + linear = nn.Linear(in_dim, out_dim) + linear.weight.data.normal_() + linear.bias.data.zero_() + + self.linear = equal_lr(linear) + + def forward(self, input): + return self.linear(input) + + +class ConvBlockProGan(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, padding, kernel_size2=None, padding2=None, + pixel_norm=True): + super().__init__() + + pad1 = padding + pad2 = padding + if padding2 is not None: + pad2 = padding2 + + kernel1 = kernel_size + kernel2 = kernel_size + if kernel_size2 is not None: + kernel2 = kernel_size2 + + convs = [EqualConv2d(in_channel, out_channel, kernel1, padding=pad1)] + if pixel_norm: + convs.append(PixelNorm()) + convs.append(nn.LeakyReLU(0.1)) + convs.append(EqualConv2d(out_channel, out_channel, kernel2, padding=pad2)) + if pixel_norm: + convs.append(PixelNorm()) + convs.append(nn.LeakyReLU(0.1)) + + self.conv = nn.Sequential(*convs) + + def forward(self, input): + out = self.conv(input) + return out + + +def upscale(feat): + return F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False) + + +class GeneratorProGan(nn.Module): + def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=True): + super().__init__() + self.input_dim = input_code_dim + self.tanh = tanh + # self.input_layer = nn.Sequential( + # EqualConvTranspose2d(input_code_dim, in_channel, 4, 1, 0), + # PixelNorm(), + # nn.LeakyReLU(0.1)) + + self.input_layer = nn.Sequential( + EqualConv2d(input_code_dim, in_channel, 3, padding=1), + PixelNorm(), + nn.LeakyReLU(0.1)) + + self.progression_4 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) + self.progression_8 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) + self.progression_16 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) + self.progression_32 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) + self.progression_64 = ConvBlockProGan(in_channel, in_channel // 2, 3, 1, pixel_norm=pixel_norm) + self.progression_128 = ConvBlockProGan(in_channel // 2, in_channel // 4, 3, 1, pixel_norm=pixel_norm) + self.progression_256 = ConvBlockProGan(in_channel // 4, in_channel // 4, 3, 1, pixel_norm=pixel_norm) + + self.to_rgb_8 = EqualConv2d(in_channel, 3, 1) + self.to_rgb_16 = EqualConv2d(in_channel, 3, 1) + self.to_rgb_32 = EqualConv2d(in_channel, 3, 1) + self.to_rgb_64 = EqualConv2d(in_channel // 2, 3, 1) + self.to_rgb_128 = EqualConv2d(in_channel // 4, 3, 1) + self.to_rgb_256 = EqualConv2d(in_channel // 4, 3, 1) + + self.max_step = 6 + + def progress(self, feat, module): + out = F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False) + out = module(out) + return out + + def output(self, feat1, feat2, module1, module2, alpha): + if 0 <= alpha < 1: + skip_rgb = upscale(module1(feat1)) + out = (1 - alpha) * skip_rgb + alpha * module2(feat2) + else: + out = 
module2(feat2) + if self.tanh: + return torch.tanh(out) + return out + + def forward(self, input, step=0, alpha=-1): + if step > self.max_step: + step = self.max_step + # out_4 = self.input_layer(input.view(-1, self.input_dim, 1, 1)) + out_4 = self.input_layer(input) + out_4 = self.progression_4(out_4) + out_8 = self.progress(out_4, self.progression_8) + if step == 1: + if self.tanh: + return torch.tanh(self.to_rgb_8(out_8)) + return self.to_rgb_8(out_8) + + out_16 = self.progress(out_8, self.progression_16) + if step == 2: + return self.output(out_8, out_16, self.to_rgb_8, self.to_rgb_16, alpha) + + out_32 = self.progress(out_16, self.progression_32) + if step == 3: + return self.output(out_16, out_32, self.to_rgb_16, self.to_rgb_32, alpha) + + out_64 = self.progress(out_32, self.progression_64) + if step == 4: + return self.output(out_32, out_64, self.to_rgb_32, self.to_rgb_64, alpha) + + out_128 = self.progress(out_64, self.progression_128) + if step == 5: + return self.output(out_64, out_128, self.to_rgb_64, self.to_rgb_128, alpha) + + out_256 = self.progress(out_128, self.progression_256) + if step == 6: + return self.output(out_128, out_256, self.to_rgb_128, self.to_rgb_256, alpha) + + +class DiscriminatorProGan(nn.Module): + def __init__(self, feat_dim=128): + super().__init__() + + self.progression = nn.ModuleList([ConvBlockProGan(feat_dim // 4, feat_dim // 4, 3, 1), + ConvBlockProGan(feat_dim // 4, feat_dim // 2, 3, 1), + ConvBlockProGan(feat_dim // 2, feat_dim, 3, 1), + ConvBlockProGan(feat_dim, feat_dim, 3, 1), + ConvBlockProGan(feat_dim, feat_dim, 3, 1), + ConvBlockProGan(feat_dim, feat_dim, 3, 1), + ConvBlockProGan(feat_dim + 1, feat_dim, 3, 1, 4, 0)]) + + self.from_rgb = nn.ModuleList([EqualConv2d(3, feat_dim // 4, 1), + EqualConv2d(3, feat_dim // 4, 1), + EqualConv2d(3, feat_dim // 2, 1), + EqualConv2d(3, feat_dim, 1), + EqualConv2d(3, feat_dim, 1), + EqualConv2d(3, feat_dim, 1), + EqualConv2d(3, feat_dim, 1)]) + + self.n_layer = len(self.progression) + + self.linear = EqualLinear(feat_dim, 1) + + def forward(self, input, step=0, alpha=-1): + for i in range(step, -1, -1): + index = self.n_layer - i - 1 + + if i == step: + out = self.from_rgb[index](input) + + if i == 0: + out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8) + mean_std = out_std.mean() + mean_std = mean_std.expand(out.size(0), 1, 4, 4) + out = torch.cat([out, mean_std], 1) + + out = self.progression[index](out) + + if i > 0: + # out = F.avg_pool2d(out, 2) + out = F.interpolate(out, scale_factor=0.5, mode='bilinear', align_corners=False) + + if i == step and 0 <= alpha < 1: + # skip_rgb = F.avg_pool2d(input, 2) + skip_rgb = F.interpolate(input, scale_factor=0.5, mode='bilinear', align_corners=False) + skip_rgb = self.from_rgb[index + 1](skip_rgb) + out = (1 - alpha) * skip_rgb + alpha * out + + out = out.squeeze(2).squeeze(2) + # print(input.size(), out.size(), step) + out = self.linear(out) + + return out diff --git a/models/progan_model.py b/models/progan_model.py new file mode 100644 index 00000000000..0f0323b4047 --- /dev/null +++ b/models/progan_model.py @@ -0,0 +1,162 @@ +import torch + +from .base_model import BaseModel +from . 
import networks +from torch.nn import functional as F +try: + from apex import amp +except ImportError: + print("Please install NVIDIA Apex for safe mixed precision if you want to use non default --opt_level") + + +class ProGanModel(BaseModel): + + @staticmethod + def modify_commandline_options(parser, is_train=True): + + parser.set_defaults(netG='progan', netD='progan', dataset_mode='single', beta1=0.) + + return parser + + def accumulate(self, decay=0.999): + par1 = dict(self.netG.named_parameters()) + par2 = dict(self.netC.named_parameters()) + + for k in par1.keys(): + par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) + + def __init__(self, opt): + + BaseModel.__init__(self, opt) + # specify the training losses you want to print out. The training/test scripts will call + self.loss_names = ['G_GAN', 'D_real', 'D_fake'] + # specify the images you want to save/display. The training/test scripts will call + self.visual_names = ['fake_B', 'real_B'] + # specify the models you want to save to the disk. The training/test scripts will call and + if self.isTrain: + self.model_names = ['G', 'D', 'C'] + else: # during test time, only load G + self.model_names = ['G'] + # define networks (both generator and discriminator) + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, + init_weights=False) + self.netC = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, + init_weights=False) + + if self.isTrain: + self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, + init_weights=False) + + assert opt.beta1 == 0 + if self.isTrain: + # define loss functions + self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) + self.criterionL1 = torch.nn.L1Loss() + # initialize optimizers; schedulers will be automatically created by function . + self.optimizer_C = torch.optim.Adam(self.netC.parameters(), lr=opt.lr, betas=(opt.beta1, 0.99)) + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.99)) + self.optimizers.append(self.optimizer_C) + self.optimizers.append(self.optimizer_D) + + if opt.apex: + [self.netC, self.netD], [self.optimizer_C, self.optimizer_D] = amp.initialize( + [self.netC, self.netD], [self.optimizer_C, self.optimizer_D], opt_level=opt.opt_level, num_losses=2) + + self.make_data_parallel() + + # inner counters + self.total_steps = opt.n_epochs + opt.n_epochs_decay + 1 + assert self.total_steps > 12 + assert self.opt.crop_size % 2 ** 6 == 0 + self.step = 1 + self.iter = 0 + self.alpha = 0. + + # set fusing + self.netG.eval() + self.accumulate(0) + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): include the data itself and its metadata information. + + The option 'direction' can be used to swap images in domain A and domain B. 
+ """ + AtoB = self.opt.direction == 'AtoB' + self.real_B = input['A' if AtoB else 'B'].to(self.device) + self.real_B = F.interpolate(self.real_B, size=(4 * 2 ** self.step, 4 * 2 ** self.step), mode='bilinear') + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + """Run forward pass; called by both functions and .""" + batch_size = self.real_B.size(0) + z = torch.randn((batch_size, 128, self.opt.crop_size // (2 ** 6), self.opt.crop_size // (2 ** 6)), device=self.device) + self.fake_B = self.netC(z, step=self.step, alpha=self.alpha) + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # Fake; stop backprop to the generator by detaching fake_B + fake_B = self.fake_B + pred_fake = self.netD(fake_B.detach(), step=self.step, alpha=self.alpha) + self.loss_D_fake = self.criterionGAN(pred_fake, False) + # Real + real_B = self.real_B + pred_real = self.netD(real_B, step=self.step, alpha=self.alpha) + self.loss_D_real = self.criterionGAN(pred_real, True) + # combine loss and calculate gradients + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + + if self.opt.apex: + with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled: + loss_D_scaled.backward() + else: + self.loss_D.backward() + + def backward_G(self): + """Calculate GAN loss for the generator""" + # First, G(A) should fake the discriminator + fake_B = self.fake_B + pred_fake = self.netD(fake_B, step=self.step, alpha=self.alpha) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + self.loss_G = self.loss_G_GAN + + if self.opt.apex: + with amp.scale_loss(self.loss_G, self.optimizer_C, loss_id=1) as loss_G_scaled: + loss_G_scaled.backward() + else: + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() # compute fake images: G(A) + # update D + self.set_requires_grad(self.netD, True) # enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + self.backward_D() # calculate gradients for D + self.optimizer_D.step() # update D's weights + # update generator C + self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G + self.optimizer_C.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G + self.optimizer_C.step() # udpate G's weights + self.accumulate() # fuse params + + def update_inners_counters(self): + self.iter += 1 + self.alpha = min(1, (2 / (self.total_steps // 6)) * self.iter) + if self.iter > self.total_steps // 6: + self.alpha = 0 + self.iter = 0 + self.step += 1 + + if self.step > 6: + self.alpha = 1 + self.step = 6 + + def update_learning_rate(self): + super(ProGanModel, self).update_learning_rate() + self.update_inners_counters() From 2142ba0499ba253fa2999d0c99cd5ea55fb9ec34 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Sat, 25 Jul 2020 11:54:44 +0300 Subject: [PATCH 08/19] Some refactorings --- models/networks.py | 26 +++++++++++-------- models/progan_model.py | 58 ++++++++++++++++++++++++++++-------------- 2 files changed, 54 insertions(+), 30 deletions(-) diff --git a/models/networks.py b/models/networks.py index c96c49ff8f3..c5ca793cb6b 100644 --- a/models/networks.py +++ b/models/networks.py @@ -157,7 +157,7 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in elif netG == 'unet_256': net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) elif netG == 'progan': - net = GeneratorProGan(input_code_dim=128) + net = GeneratorProGan(input_code_dim=input_nc, 
in_channel=ngf) else: raise NotImplementedError('Generator model name [%s] is not recognized' % netG) return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights) @@ -204,7 +204,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal' elif netD == 'pixel': # classify if each pixel is real or fake net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) elif netD == 'progan': - net = DiscriminatorProGan() + net = DiscriminatorProGan(feat_dim=ndf, in_dim=input_nc) else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights) @@ -750,9 +750,13 @@ def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=Tru # nn.LeakyReLU(0.1)) self.input_layer = nn.Sequential( + EqualConv2d(input_code_dim, input_code_dim, 3, padding=1), + PixelNorm(), + nn.LeakyReLU(0.1), EqualConv2d(input_code_dim, in_channel, 3, padding=1), PixelNorm(), - nn.LeakyReLU(0.1)) + nn.LeakyReLU(0.1) + ) self.progression_4 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) self.progression_8 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) @@ -820,7 +824,7 @@ def forward(self, input, step=0, alpha=-1): class DiscriminatorProGan(nn.Module): - def __init__(self, feat_dim=128): + def __init__(self, feat_dim=128, in_dim=3): super().__init__() self.progression = nn.ModuleList([ConvBlockProGan(feat_dim // 4, feat_dim // 4, 3, 1), @@ -831,13 +835,13 @@ def __init__(self, feat_dim=128): ConvBlockProGan(feat_dim, feat_dim, 3, 1), ConvBlockProGan(feat_dim + 1, feat_dim, 3, 1, 4, 0)]) - self.from_rgb = nn.ModuleList([EqualConv2d(3, feat_dim // 4, 1), - EqualConv2d(3, feat_dim // 4, 1), - EqualConv2d(3, feat_dim // 2, 1), - EqualConv2d(3, feat_dim, 1), - EqualConv2d(3, feat_dim, 1), - EqualConv2d(3, feat_dim, 1), - EqualConv2d(3, feat_dim, 1)]) + self.from_rgb = nn.ModuleList([EqualConv2d(in_dim, feat_dim // 4, 1), + EqualConv2d(in_dim, feat_dim // 4, 1), + EqualConv2d(in_dim, feat_dim // 2, 1), + EqualConv2d(in_dim, feat_dim, 1), + EqualConv2d(in_dim, feat_dim, 1), + EqualConv2d(in_dim, feat_dim, 1), + EqualConv2d(in_dim, feat_dim, 1)]) self.n_layer = len(self.progression) diff --git a/models/progan_model.py b/models/progan_model.py index 0f0323b4047..56ba8e4a4b1 100644 --- a/models/progan_model.py +++ b/models/progan_model.py @@ -3,6 +3,7 @@ from .base_model import BaseModel from . import networks from torch.nn import functional as F + try: from apex import amp except ImportError: @@ -10,24 +11,26 @@ class ProGanModel(BaseModel): + """ + This is an implementation of the paper "Progressive Growing of GANs": https://arxiv.org/abs/1710.10196. + Model requires dataset of type dataset_mode='single', generator netG='progan', discriminator netD='progan'. + ngf and ndf controlls dimensions of the backbone. + Network G is a master-generator (for eval) and network C (stands for current) is a current trainable generator. + """ @staticmethod def modify_commandline_options(parser, is_train=True): - parser.set_defaults(netG='progan', netD='progan', dataset_mode='single', beta1=0.) 
- + parser.set_defaults(netG='progan', netD='progan', dataset_mode='single', beta1=0., ngf=512, ndf=512) + parser.add_argument('--z_dim', type=int, default=32, help='random noise dim') + parser.add_argument('--max_steps', type=int, default=6, help='steps of growing') return parser - def accumulate(self, decay=0.999): - par1 = dict(self.netG.named_parameters()) - par2 = dict(self.netC.named_parameters()) - - for k in par1.keys(): - par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) - def __init__(self, opt): BaseModel.__init__(self, opt) + self.z_dim = opt.z_dim + self.max_steps = opt.max_steps # specify the training losses you want to print out. The training/test scripts will call self.loss_names = ['G_GAN', 'D_real', 'D_fake'] # specify the images you want to save/display. The training/test scripts will call @@ -38,15 +41,15 @@ def __init__(self, opt): else: # during test time, only load G self.model_names = ['G'] # define networks (both generator and discriminator) - self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + self.netG = networks.define_G(opt.z_dim, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, init_weights=False) - self.netC = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + self.netC = networks.define_G(opt.z_dim, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, init_weights=False) if self.isTrain: - self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, + self.netD = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, init_weights=False) @@ -69,12 +72,13 @@ def __init__(self, opt): # inner counters self.total_steps = opt.n_epochs + opt.n_epochs_decay + 1 - assert self.total_steps > 12 - assert self.opt.crop_size % 2 ** 6 == 0 self.step = 1 self.iter = 0 self.alpha = 0. 
+ assert self.total_steps > 12 + assert self.opt.crop_size % 2 ** self.max_steps == 0 + # set fusing self.netG.eval() self.accumulate(0) @@ -95,7 +99,9 @@ def set_input(self, input): def forward(self): """Run forward pass; called by both functions and .""" batch_size = self.real_B.size(0) - z = torch.randn((batch_size, 128, self.opt.crop_size // (2 ** 6), self.opt.crop_size // (2 ** 6)), device=self.device) + z = torch.randn((batch_size, self.z_dim, self.opt.crop_size // (2 ** self.max_steps), + self.opt.crop_size // (2 ** self.max_steps)), + device=self.device) self.fake_B = self.netC(z, step=self.step, alpha=self.alpha) def backward_D(self): @@ -146,16 +152,30 @@ def optimize_parameters(self): self.accumulate() # fuse params def update_inners_counters(self): + """ + Update counters of iterations + """ self.iter += 1 - self.alpha = min(1, (2 / (self.total_steps // 6)) * self.iter) - if self.iter > self.total_steps // 6: + self.alpha = min(1, (2 / (self.total_steps // self.max_steps)) * self.iter) + if self.iter > self.total_steps // self.max_steps: self.alpha = 0 self.iter = 0 self.step += 1 - if self.step > 6: + if self.step > self.max_steps: self.alpha = 1 - self.step = 6 + self.step = self.max_steps + + def accumulate(self, decay=0.999): + """ + Accumulate weights from self.C to self.G with decay + @param decay decay + """ + par1 = dict(self.netG.named_parameters()) + par2 = dict(self.netC.named_parameters()) + + for k in par1.keys(): + par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) def update_learning_rate(self): super(ProGanModel, self).update_learning_rate() From 1ffea2289a200c7963f5e7b5f62f6d7e9a4c4b21 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Sat, 25 Jul 2020 13:09:00 +0300 Subject: [PATCH 09/19] Refactor --- models/networks.py | 45 ++++++++++++++++++++++++------------------ models/progan_model.py | 26 ++++++++++++++---------- 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/models/networks.py b/models/networks.py index c5ca793cb6b..2cc4fbdc97d 100644 --- a/models/networks.py +++ b/models/networks.py @@ -118,7 +118,7 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_weights_= def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], - init_weights=True): + init_weights=True, **kwargs): """Create a generator Parameters: @@ -157,14 +157,14 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in elif netG == 'unet_256': net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) elif netG == 'progan': - net = GeneratorProGan(input_code_dim=input_nc, in_channel=ngf) + net = GeneratorProGan(input_code_dim=input_nc, in_channel=ngf, max_steps=kwargs['max_steps'], out_channels=output_nc) else: raise NotImplementedError('Generator model name [%s] is not recognized' % netG) return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights) def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], - init_weights=True): + init_weights=True, **kwargs): """Create a discriminator Parameters: @@ -204,7 +204,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal' elif netD == 'pixel': # classify if each pixel is real or fake net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) elif netD == 'progan': - net = DiscriminatorProGan(feat_dim=ndf, in_dim=input_nc) + net = DiscriminatorProGan(feat_dim=ndf, 
in_dim=input_nc, max_steps=kwargs['max_steps']) else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights) @@ -740,10 +740,11 @@ def upscale(feat): class GeneratorProGan(nn.Module): - def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=True): + def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=True, max_steps=6, out_channels=3): super().__init__() - self.input_dim = input_code_dim + self.z_dim = input_code_dim self.tanh = tanh + self.max_steps = max_steps # self.input_layer = nn.Sequential( # EqualConvTranspose2d(input_code_dim, in_channel, 4, 1, 0), # PixelNorm(), @@ -765,15 +766,15 @@ def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=Tru self.progression_64 = ConvBlockProGan(in_channel, in_channel // 2, 3, 1, pixel_norm=pixel_norm) self.progression_128 = ConvBlockProGan(in_channel // 2, in_channel // 4, 3, 1, pixel_norm=pixel_norm) self.progression_256 = ConvBlockProGan(in_channel // 4, in_channel // 4, 3, 1, pixel_norm=pixel_norm) + self.progression_512 = ConvBlockProGan(in_channel // 4, in_channel // 8, 3, 1, pixel_norm=pixel_norm) - self.to_rgb_8 = EqualConv2d(in_channel, 3, 1) - self.to_rgb_16 = EqualConv2d(in_channel, 3, 1) - self.to_rgb_32 = EqualConv2d(in_channel, 3, 1) - self.to_rgb_64 = EqualConv2d(in_channel // 2, 3, 1) - self.to_rgb_128 = EqualConv2d(in_channel // 4, 3, 1) - self.to_rgb_256 = EqualConv2d(in_channel // 4, 3, 1) - - self.max_step = 6 + self.to_rgb_8 = EqualConv2d(in_channel, out_channels, 1) + self.to_rgb_16 = EqualConv2d(in_channel, out_channels, 1) + self.to_rgb_32 = EqualConv2d(in_channel, out_channels, 1) + self.to_rgb_64 = EqualConv2d(in_channel // 2, out_channels, 1) + self.to_rgb_128 = EqualConv2d(in_channel // 4, out_channels, 1) + self.to_rgb_256 = EqualConv2d(in_channel // 4, out_channels, 1) + self.to_rgb_512 = EqualConv2d(in_channel // 8, out_channels, 1) def progress(self, feat, module): out = F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False) @@ -791,8 +792,8 @@ def output(self, feat1, feat2, module1, module2, alpha): return out def forward(self, input, step=0, alpha=-1): - if step > self.max_step: - step = self.max_step + if step > self.max_steps: + step = self.max_steps # out_4 = self.input_layer(input.view(-1, self.input_dim, 1, 1)) out_4 = self.input_layer(input) out_4 = self.progression_4(out_4) @@ -822,12 +823,17 @@ def forward(self, input, step=0, alpha=-1): if step == 6: return self.output(out_128, out_256, self.to_rgb_128, self.to_rgb_256, alpha) + out_512 = self.progress(out_256, self.progression_512) + if step == 7: + return self.output(out_256, out_512, self.to_rgb_256, self.to_rgb_512, alpha) + class DiscriminatorProGan(nn.Module): - def __init__(self, feat_dim=128, in_dim=3): + def __init__(self, feat_dim=128, in_dim=3, max_steps=6): super().__init__() - + self.max_steps = max_steps self.progression = nn.ModuleList([ConvBlockProGan(feat_dim // 4, feat_dim // 4, 3, 1), + ConvBlockProGan(feat_dim // 4, feat_dim // 4, 3, 1), ConvBlockProGan(feat_dim // 4, feat_dim // 2, 3, 1), ConvBlockProGan(feat_dim // 2, feat_dim, 3, 1), ConvBlockProGan(feat_dim, feat_dim, 3, 1), @@ -836,6 +842,7 @@ def __init__(self, feat_dim=128, in_dim=3): ConvBlockProGan(feat_dim + 1, feat_dim, 3, 1, 4, 0)]) self.from_rgb = nn.ModuleList([EqualConv2d(in_dim, feat_dim // 4, 1), + EqualConv2d(in_dim, feat_dim // 4, 1), EqualConv2d(in_dim, feat_dim // 4, 1), 
EqualConv2d(in_dim, feat_dim // 2, 1), EqualConv2d(in_dim, feat_dim, 1), @@ -857,7 +864,7 @@ def forward(self, input, step=0, alpha=-1): if i == 0: out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8) mean_std = out_std.mean() - mean_std = mean_std.expand(out.size(0), 1, 4, 4) + mean_std = mean_std.expand(out.size(0), 1, out_std.size(1), out_std.size(2)) out = torch.cat([out, mean_std], 1) out = self.progression[index](out) diff --git a/models/progan_model.py b/models/progan_model.py index 56ba8e4a4b1..0a539d5993d 100644 --- a/models/progan_model.py +++ b/models/progan_model.py @@ -40,20 +40,24 @@ def __init__(self, opt): self.model_names = ['G', 'D', 'C'] else: # during test time, only load G self.model_names = ['G'] + + if self.isTrain: + assert opt.crop_size == 4 * 2 ** self.max_steps + assert opt.beta1 == 0 + # define networks (both generator and discriminator) - self.netG = networks.define_G(opt.z_dim, opt.output_nc, opt.ngf, opt.netG, opt.norm, + self.netG = networks.define_G(opt.z_dim, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, - init_weights=False) - self.netC = networks.define_G(opt.z_dim, opt.output_nc, opt.ngf, opt.netG, opt.norm, - not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, - init_weights=False) + init_weights=False, max_steps=self.max_steps) if self.isTrain: + self.netC = networks.define_G(opt.z_dim, opt.input_nc, opt.ngf, opt.netG, opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, + init_weights=False, max_steps=self.max_steps) self.netD = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, - init_weights=False) + init_weights=False, max_steps=self.max_steps) - assert opt.beta1 == 0 if self.isTrain: # define loss functions self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) @@ -76,8 +80,9 @@ def __init__(self, opt): self.iter = 0 self.alpha = 0. 
- assert self.total_steps > 12 - assert self.opt.crop_size % 2 ** self.max_steps == 0 + if self.isTrain: + assert self.total_steps > 12 + assert self.opt.crop_size % 2 ** self.max_steps == 0 # set fusing self.netG.eval() @@ -98,11 +103,12 @@ def set_input(self, input): def forward(self): """Run forward pass; called by both functions and .""" + net = self.netC if self.isTrain else self.netG batch_size = self.real_B.size(0) z = torch.randn((batch_size, self.z_dim, self.opt.crop_size // (2 ** self.max_steps), self.opt.crop_size // (2 ** self.max_steps)), device=self.device) - self.fake_B = self.netC(z, step=self.step, alpha=self.alpha) + self.fake_B = net(z, step=self.step, alpha=self.alpha) def backward_D(self): """Calculate GAN loss for the discriminator""" From cadd1fdd6aa5a95c03ecc6a09c47ef9cfd1d4b1a Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Sat, 25 Jul 2020 13:11:52 +0300 Subject: [PATCH 10/19] Fix dim --- models/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/networks.py b/models/networks.py index 2cc4fbdc97d..8d6b6d742f9 100644 --- a/models/networks.py +++ b/models/networks.py @@ -832,7 +832,7 @@ class DiscriminatorProGan(nn.Module): def __init__(self, feat_dim=128, in_dim=3, max_steps=6): super().__init__() self.max_steps = max_steps - self.progression = nn.ModuleList([ConvBlockProGan(feat_dim // 4, feat_dim // 4, 3, 1), + self.progression = nn.ModuleList([ConvBlockProGan(feat_dim // 8, feat_dim // 4, 3, 1), ConvBlockProGan(feat_dim // 4, feat_dim // 4, 3, 1), ConvBlockProGan(feat_dim // 4, feat_dim // 2, 3, 1), ConvBlockProGan(feat_dim // 2, feat_dim, 3, 1), @@ -841,7 +841,7 @@ def __init__(self, feat_dim=128, in_dim=3, max_steps=6): ConvBlockProGan(feat_dim, feat_dim, 3, 1), ConvBlockProGan(feat_dim + 1, feat_dim, 3, 1, 4, 0)]) - self.from_rgb = nn.ModuleList([EqualConv2d(in_dim, feat_dim // 4, 1), + self.from_rgb = nn.ModuleList([EqualConv2d(in_dim, feat_dim // 8, 1), EqualConv2d(in_dim, feat_dim // 4, 1), EqualConv2d(in_dim, feat_dim // 4, 1), EqualConv2d(in_dim, feat_dim // 2, 1), From 6b95f6de5a21b6f419f189b3b778236a62162c29 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Sat, 25 Jul 2020 13:26:43 +0300 Subject: [PATCH 11/19] Doc --- models/progan_model.py | 31 +++++++++++++++++++++++++++++-- scripts/train_progan.sh | 2 ++ 2 files changed, 31 insertions(+), 2 deletions(-) create mode 100755 scripts/train_progan.sh diff --git a/models/progan_model.py b/models/progan_model.py index 0a539d5993d..42c4937900d 100644 --- a/models/progan_model.py +++ b/models/progan_model.py @@ -14,8 +14,14 @@ class ProGanModel(BaseModel): """ This is an implementation of the paper "Progressive Growing of GANs": https://arxiv.org/abs/1710.10196. Model requires dataset of type dataset_mode='single', generator netG='progan', discriminator netD='progan'. - ngf and ndf controlls dimensions of the backbone. - Network G is a master-generator (for eval) and network C (stands for current) is a current trainable generator. + Please note that opt.crop_size (default 256) == 4 * 2 ** opt.max_steps (default max_steps is 6). + ngf and ndf controlls dimensions of the backbone (128-512). + Network G is a master-generator (accumulates weights for eval) and network C (stands for current) is a + current trainable generator. 
+ + See also: + https://github.com/tkarras/progressive_growing_of_gans + https://github.com/odegeasslbc/Progressive-GAN-pytorch """ @staticmethod @@ -49,14 +55,23 @@ def __init__(self, opt): self.netG = networks.define_G(opt.z_dim, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, init_weights=False, max_steps=self.max_steps) + """ + resulting generator, not for training, just for eval + """ if self.isTrain: self.netC = networks.define_G(opt.z_dim, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, init_weights=False, max_steps=self.max_steps) + """ + current training generator + """ self.netD = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, init_weights=False, max_steps=self.max_steps) + """ + current training discr. + """ if self.isTrain: # define loss functions @@ -76,9 +91,21 @@ def __init__(self, opt): # inner counters self.total_steps = opt.n_epochs + opt.n_epochs_decay + 1 + """ + total epochs + """ self.step = 1 + """ + current step of network, 1-6 + """ self.iter = 0 + """ + current iter, 0-(total epochs)//6 + """ self.alpha = 0. + """ + current alpha rate to fuse different scales + """ if self.isTrain: assert self.total_steps > 12 diff --git a/scripts/train_progan.sh b/scripts/train_progan.sh new file mode 100755 index 00000000000..861c69e2585 --- /dev/null +++ b/scripts/train_progan.sh @@ -0,0 +1,2 @@ +set -ex +python train.py --dataroot ./datasets/facades --name progan_facades --model progan --pool_size 10 --dataset_mode single --batch_size 1 --crop_size 256 --preprocess crop From 19bce11b10c408ea20e2388143628c7b9e062200 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Sun, 26 Jul 2020 01:39:11 +0300 Subject: [PATCH 12/19] Fix forward --- models/progan_model.py | 4 +++- scripts/train_progan.sh | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/models/progan_model.py b/models/progan_model.py index 42c4937900d..d1445d0ef8c 100644 --- a/models/progan_model.py +++ b/models/progan_model.py @@ -131,11 +131,13 @@ def set_input(self, input): def forward(self): """Run forward pass; called by both functions and .""" net = self.netC if self.isTrain else self.netG + step = self.step if self.isTrain else self.max_steps + alpha = self.alpha if self.isTrain else 1 batch_size = self.real_B.size(0) z = torch.randn((batch_size, self.z_dim, self.opt.crop_size // (2 ** self.max_steps), self.opt.crop_size // (2 ** self.max_steps)), device=self.device) - self.fake_B = net(z, step=self.step, alpha=self.alpha) + self.fake_B = net(z, step=step, alpha=alpha) def backward_D(self): """Calculate GAN loss for the discriminator""" diff --git a/scripts/train_progan.sh b/scripts/train_progan.sh index 861c69e2585..aedbf5c4bfd 100755 --- a/scripts/train_progan.sh +++ b/scripts/train_progan.sh @@ -1,2 +1,2 @@ set -ex -python train.py --dataroot ./datasets/facades --name progan_facades --model progan --pool_size 10 --dataset_mode single --batch_size 1 --crop_size 256 --preprocess crop +python train.py --dataroot ./datasets/facades --name progan_train --model progan --pool_size 50 --dataset_mode single --batch_size 1 --crop_size 256 --load_size 300 --preprocess resize_and_crop From b5825d43fa6f42018c7a91704bd19b300b7ed01b Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Mon, 27 Jul 2020 11:15:07 +0300 Subject: [PATCH 13/19] Fix apex epsilon --- models/networks.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/models/networks.py b/models/networks.py index 8d6b6d742f9..d35768acbf2 100644 --- a/models/networks.py +++ b/models/networks.py @@ -862,7 +862,7 @@ def forward(self, input, step=0, alpha=-1): out = self.from_rgb[index](input) if i == 0: - out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8) + out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-6) mean_std = out_std.mean() mean_std = mean_std.expand(out.size(0), 1, out_std.size(1), out_std.size(2)) out = torch.cat([out, mean_std], 1) From 8df0d9ce8280073f24accf243c4224617c4beb21 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Mon, 27 Jul 2020 15:15:44 +0300 Subject: [PATCH 14/19] Add wgangp loss to progan --- models/progan_model.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/models/progan_model.py b/models/progan_model.py index d1445d0ef8c..6a7a38278fa 100644 --- a/models/progan_model.py +++ b/models/progan_model.py @@ -1,9 +1,12 @@ import torch +from torch.autograd import grad from .base_model import BaseModel from . import networks from torch.nn import functional as F +from .networks import cal_gradient_penalty + try: from apex import amp except ImportError: @@ -27,7 +30,7 @@ class ProGanModel(BaseModel): @staticmethod def modify_commandline_options(parser, is_train=True): - parser.set_defaults(netG='progan', netD='progan', dataset_mode='single', beta1=0., ngf=512, ndf=512) + parser.set_defaults(netG='progan', netD='progan', dataset_mode='single', beta1=0., ngf=512, ndf=512, gan_mode='wgangp') parser.add_argument('--z_dim', type=int, default=32, help='random noise dim') parser.add_argument('--max_steps', type=int, default=6, help='steps of growing') return parser @@ -38,7 +41,9 @@ def __init__(self, opt): self.z_dim = opt.z_dim self.max_steps = opt.max_steps # specify the training losses you want to print out. The training/test scripts will call - self.loss_names = ['G_GAN', 'D_real', 'D_fake'] + self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'D'] + if self.opt.gan_mode == 'wgangp': + self.loss_names.append('D_gradpen') # specify the images you want to save/display. The training/test scripts will call self.visual_names = ['fake_B', 'real_B'] # specify the models you want to save to the disk. 
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
@@ -149,8 +154,24 @@ def backward_D(self):
         real_B = self.real_B
         pred_real = self.netD(real_B, step=self.step, alpha=self.alpha)
         self.loss_D_real = self.criterionGAN(pred_real, True)
+        if self.opt.gan_mode == 'wgangp':
+            # eps_drift penalty from the ProGAN paper: keeps D's real scores near zero
+            self.loss_D_real += 0.001 * (pred_real ** 2).mean()
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+        if self.opt.gan_mode == 'wgangp':
+            # WGAN-GP gradient penalty for D on random real/fake interpolates
+            b_size = fake_B.size(0)
+            eps = torch.rand(b_size, 1, 1, 1, dtype=fake_B.dtype, device=fake_B.device)
+            x_hat = eps * real_B.data + (1 - eps) * fake_B.detach().data
+            x_hat.requires_grad = True
+            hat_predict = self.netD(x_hat, step=self.step, alpha=self.alpha)
+            grad_x_hat = grad(
+                outputs=hat_predict.sum(), inputs=x_hat, create_graph=True)[0]
+            grad_penalty = ((grad_x_hat.view(grad_x_hat.size(0), -1)
+                             .norm(2, dim=1) - 1) ** 2).mean()
+            self.loss_D_gradpen = 10 * grad_penalty
+            self.loss_D += self.loss_D_gradpen
 
         if self.opt.apex:
             with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled:

From 1f58ab7df713e983b302af39eaf142a98bade1e8 Mon Sep 17 00:00:00 2001
From: seovchinnikov
Date: Mon, 27 Jul 2020 18:27:25 +0300
Subject: [PATCH 15/19] Add steps scheduling

---
 models/progan_model.py | 26 +++++++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/models/progan_model.py b/models/progan_model.py
index 6a7a38278fa..bdca862a656 100644
--- a/models/progan_model.py
+++ b/models/progan_model.py
@@ -4,6 +4,7 @@
 from .base_model import BaseModel
 from . import networks
 from torch.nn import functional as F
+import numpy as np
 
 from .networks import cal_gradient_penalty
 
@@ -30,9 +31,12 @@ class ProGanModel(BaseModel):
 
     @staticmethod
     def modify_commandline_options(parser, is_train=True):
-        parser.set_defaults(netG='progan', netD='progan', dataset_mode='single', beta1=0., ngf=512, ndf=512, gan_mode='wgangp')
+        parser.set_defaults(netG='progan', netD='progan', dataset_mode='single', beta1=0., ngf=512, ndf=512,
+                            gan_mode='wgangp')
         parser.add_argument('--z_dim', type=int, default=32, help='random noise dim')
         parser.add_argument('--max_steps', type=int, default=6, help='steps of growing')
+        parser.add_argument('--steps_schedule', type=str, default='linear',
+                            help='type of when to turn to the next step: linear or fibonacci')
         return parser
 
@@ -111,6 +115,10 @@ def __init__(self, opt):
+        self.epochs_schedule = self.create_epochs_schedule(opt.steps_schedule)
+        """
+        schedule when to turn to the next step
+        """
 
         if self.isTrain:
             assert self.total_steps > 12
@@ -207,13 +215,25 @@ def optimize_parameters(self):
         self.optimizer_C.step()  # udpate G's weights
         self.accumulate()  # fuse params
 
+    def create_epochs_schedule(self, steps_schedule_type):
+        if steps_schedule_type == 'fibonacci':
+            basic_weights = np.array([0., 1., 2., 3., 5., 8., 13., 21.])
+        else:
+            basic_weights = np.array([0., 1, 1, 1, 1, 1, 1, 1])
+        basic_weights = basic_weights[:self.max_steps + 1]
+        epochs_schedule = self.total_steps * basic_weights / np.sum(basic_weights)
+        epochs_schedule = epochs_schedule.astype(np.int)
+        print('schedule of step turning: %s' % str(epochs_schedule))
+        return epochs_schedule
+
     def update_inners_counters(self):
         """
         Update counters of iterations
         """
         self.iter += 1
-        self.alpha = min(1, (2 / (self.total_steps // self.max_steps)) * self.iter)
-        if self.iter > self.total_steps //
self.max_steps: + self.alpha = min(1, (2. / (self.epochs_schedule[self.step])) * self.iter) + if self.iter > self.epochs_schedule[self.step]: + print('turn to step % s' % str(self.step + 1)) self.alpha = 0 self.iter = 0 self.step += 1 From 3c8b00aaf93b1ddde3bea9aaf574177bf73cf01d Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Mon, 27 Jul 2020 18:34:30 +0300 Subject: [PATCH 16/19] Add steps scheduling --- models/progan_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/progan_model.py b/models/progan_model.py index bdca862a656..5b9661edfdf 100644 --- a/models/progan_model.py +++ b/models/progan_model.py @@ -217,9 +217,9 @@ def optimize_parameters(self): def create_epochs_schedule(self, steps_schedule_type): if steps_schedule_type == 'fibonacci': - basic_weights = np.array([0., 1., 2., 3., 5., 8., 13., 21.]) + basic_weights = np.array([0., 1., 2., 3., 5., 8., 13., 21., 34.]) else: - basic_weights = np.array([0., 1, 1, 1, 1, 1, 1, 1]) + basic_weights = np.array([0., 1, 1, 1, 1, 1, 1, 1, 1]) basic_weights = basic_weights[:self.max_steps + 1] epochs_schedule = self.total_steps * basic_weights / np.sum(basic_weights) epochs_schedule = epochs_schedule.astype(np.int) From 620dd5ca652a5f6cb893946ff9dfa6cfc1261f3e Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Tue, 28 Jul 2020 15:14:37 +0300 Subject: [PATCH 17/19] New ProGan model --- models/networks.py | 387 ++++++++++++------------------ models/progan_layers.py | 517 ++++++++++++++++++++++++++++++++++++++++ models/progan_model.py | 56 ++++- 3 files changed, 719 insertions(+), 241 deletions(-) create mode 100644 models/progan_layers.py diff --git a/models/networks.py b/models/networks.py index d35768acbf2..2c9989819f3 100644 --- a/models/networks.py +++ b/models/networks.py @@ -2,9 +2,10 @@ import torch.nn as nn from torch.nn import init import functools + +from torch.nn.functional import interpolate from torch.optim import lr_scheduler -from torch.nn import functional as F -from math import sqrt +import numpy as np ############################################################################### # Helper Functions @@ -157,7 +158,8 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in elif netG == 'unet_256': net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) elif netG == 'progan': - net = GeneratorProGan(input_code_dim=input_nc, in_channel=ngf, max_steps=kwargs['max_steps'], out_channels=output_nc) + net = GeneratorProGanV2(input_code_dim=input_nc, in_channel=ngf, max_steps=kwargs['max_steps']+1, out_channels=output_nc) + # net = GeneratorProGan(input_code_dim=input_nc, in_channel=ngf, max_steps=kwargs['max_steps'], out_channels=output_nc) else: raise NotImplementedError('Generator model name [%s] is not recognized' % netG) return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights) @@ -204,7 +206,8 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal' elif netD == 'pixel': # classify if each pixel is real or fake net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) elif netD == 'progan': - net = DiscriminatorProGan(feat_dim=ndf, in_dim=input_nc, max_steps=kwargs['max_steps']) + net = DiscriminatorProGanV2(feat_dim=ndf, in_dim=input_nc, max_steps=kwargs['max_steps']+1) + # net = DiscriminatorProGan(feat_dim=ndf, in_dim=input_nc, max_steps=kwargs['max_steps']) else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) return init_net(net, 
init_type, init_gain, gpu_ids, init_weights_=init_weights) @@ -622,265 +625,193 @@ def forward(self, input): return self.net(input) -class EqualLR: - def __init__(self, name): - self.name = name - - def compute_weight(self, module): - weight = getattr(module, self.name + '_orig') - fan_in = weight.data.size(1) * weight.data[0][0].numel() +# ======================================================================================== +# Generator Module of ProGAN +# can be used with ProGAN or standalone (for inference) +# Thanks to https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/ +# ======================================================================================== - return weight * sqrt(2 / fan_in) - @staticmethod - def apply(module, name): - fn = EqualLR(name) +class GeneratorProGanV2(nn.Module): + """ Generator of the GAN network """ - weight = getattr(module, name) - del module._parameters[name] - module.register_parameter(name + '_orig', nn.Parameter(weight.data)) - module.register_forward_pre_hook(fn) + def __init__(self, max_steps=7, input_code_dim=512, in_channel=512, out_channels=3, use_eql=True): + """ + constructor for the Generator class + :param max_steps: required depth of the Network + :param input_code_dim: size of the latent manifold + :param use_eql: whether to use equalized learning rate + """ + from .progan_layers import GenGeneralConvBlock, GenInitialBlock, _equalized_conv2d - return fn + super(GeneratorProGanV2, self).__init__() - def __call__(self, module, input): - weight = self.compute_weight(module) - setattr(module, self.name, weight) + assert input_code_dim != 0 and ((input_code_dim & (input_code_dim - 1)) == 0), \ + "latent size not a power of 2" + if max_steps >= 4: + assert in_channel >= np.power(2, max_steps - 4), "in_channel size will diminish to zero" + # state of the generator: + self.use_eql = use_eql + self.depth = max_steps + self.latent_size = input_code_dim + self.channels_conv = in_channel -def equal_lr(module, name='weight'): - EqualLR.apply(module, name) + # register the modules required for the GAN + self.initial_block = GenInitialBlock(in_channels=self.latent_size, out_channels=in_channel, use_eql=self.use_eql) - return module + # create a module list of the other required general convolution blocks + self.layers = nn.ModuleList([]) # initialize to empty list + # create the ToRGB layers for various outputs: + if self.use_eql: + self.toRGB = lambda in_channels: \ + _equalized_conv2d(in_channels, out_channels, (1, 1), bias=True) + else: + from torch.nn import Conv2d + self.toRGB = lambda in_channels: Conv2d(in_channels, out_channels, (1, 1), bias=True) -class PixelNorm(nn.Module): - def __init__(self): - super().__init__() + self.rgb_converters = nn.ModuleList([self.toRGB(self.channels_conv)]) - def forward(self, input): - return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) - + 1e-8) + # create the remaining layers + for i in range(self.depth - 1): + if i <= 2: + layer = GenGeneralConvBlock(self.channels_conv, + self.channels_conv, use_eql=self.use_eql) + rgb = self.toRGB(self.channels_conv) + else: + layer = GenGeneralConvBlock( + int(self.channels_conv // np.power(2, i - 3)), + int(self.channels_conv // np.power(2, i - 2)), + use_eql=self.use_eql + ) + rgb = self.toRGB(int(self.channels_conv // np.power(2, i - 2))) + self.layers.append(layer) + self.rgb_converters.append(rgb) + + # register the temporary upsampler + self.temporaryUpsampler = lambda x: interpolate(x, scale_factor=2) + + def forward(self, x, step, 
alpha):
+        """
+        forward pass of the Generator
+        :param x: input noise
+        :param step: current depth from where output is required
+        :param alpha: value of alpha for fade-in effect
+        :return: y => output
+        """
+        # step = step - 1
+        assert step < self.depth, "Requested output depth cannot be produced"
+
+        y = self.initial_block(x)
+
+        if step > 0:
+            for block in self.layers[:step - 1]:
+                y = block(y)
+
+            residual = self.rgb_converters[step - 1](self.temporaryUpsampler(y))
+            straight = self.rgb_converters[step](self.layers[step - 1](y))
+
+            out = (alpha * straight) + ((1 - alpha) * residual)
+
+        else:
+            out = self.rgb_converters[0](y)
+
+        return out
+
+
+# ========================================================================================
+# Discriminator Module of ProGAN
+# can be used with ProGAN or standalone (for inference).
+# Thanks to https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/
+# ========================================================================================
+
+
+class DiscriminatorProGanV2(nn.Module):
+    """ Discriminator of the GAN """
+
+    def __init__(self, max_steps=7, feat_dim=512, in_dim=3, use_eql=True):
+        """
+        constructor for the class
+        :param max_steps: total height of the discriminator (Must be equal to the Generator depth)
+        :param feat_dim: size of the deepest features extracted
+                         (Must be equal to Generator latent_size)
+        :param use_eql: whether to use equalized learning rate
+        """
+        from torch.nn import ModuleList, AvgPool2d
+        from .progan_layers import DisGeneralConvBlock, DisFinalBlock, _equalized_conv2d
+
+        super(DiscriminatorProGanV2, self).__init__()
+
+        assert feat_dim != 0 and ((feat_dim & (feat_dim - 1)) == 0), \
+            "feat_dim not a power of 2"
+        if max_steps >= 4:
+            assert feat_dim >= np.power(2, max_steps - 4), "feature size cannot be produced"
+
+        # create state of the object
+        self.use_eql = use_eql
+        self.height = max_steps
+        self.feature_size = feat_dim
+
+        self.final_block = DisFinalBlock(self.feature_size, use_eql=self.use_eql)
+
+        # create a module list of the other required general convolution blocks
+        self.layers = ModuleList([])  # initialize to empty list
+
+        # create the fromRGB layers (in_dim input channels) for various inputs:
+        if self.use_eql:
+            self.fromRGB = lambda out_channels: \
+                _equalized_conv2d(in_dim, out_channels, (1, 1), bias=True)
+        else:
+            from torch.nn import Conv2d
+            self.fromRGB = lambda out_channels: Conv2d(in_dim, out_channels, (1, 1), bias=True)
+
+        
self.rgb_to_features = ModuleList([self.fromRGB(self.feature_size)]) + + # create the remaining layers + for i in range(self.height - 1): + if i > 2: + layer = DisGeneralConvBlock( + int(self.feature_size // np.power(2, i - 2)), + int(self.feature_size // np.power(2, i - 3)), + use_eql=self.use_eql + ) + rgb = self.fromRGB(int(self.feature_size // np.power(2, i - 2))) + else: + layer = DisGeneralConvBlock(self.feature_size, + self.feature_size, use_eql=self.use_eql) + rgb = self.fromRGB(self.feature_size) - kernel1 = kernel_size - kernel2 = kernel_size - if kernel_size2 is not None: - kernel2 = kernel_size2 + self.layers.append(layer) + self.rgb_to_features.append(rgb) - convs = [EqualConv2d(in_channel, out_channel, kernel1, padding=pad1)] - if pixel_norm: - convs.append(PixelNorm()) - convs.append(nn.LeakyReLU(0.1)) - convs.append(EqualConv2d(out_channel, out_channel, kernel2, padding=pad2)) - if pixel_norm: - convs.append(PixelNorm()) - convs.append(nn.LeakyReLU(0.1)) + # register the temporary downSampler + self.temporaryDownsampler = AvgPool2d(2) - self.conv = nn.Sequential(*convs) + def forward(self, x, step, alpha): + """ + forward pass of the discriminator + :param x: input to the network + :param step: current height of operation (Progressive GAN) + :param alpha: current value of alpha for fade-in + :return: out => raw prediction values (WGAN-GP) + """ + # step = step - 1 + assert step < self.height, "Requested output depth cannot be produced" - def forward(self, input): - out = self.conv(input) - return out + if step > 0: + residual = self.rgb_to_features[step - 1](self.temporaryDownsampler(x)) + straight = self.layers[step - 1]( + self.rgb_to_features[step](x) + ) -def upscale(feat): - return F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False) - - -class GeneratorProGan(nn.Module): - def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=True, max_steps=6, out_channels=3): - super().__init__() - self.z_dim = input_code_dim - self.tanh = tanh - self.max_steps = max_steps - # self.input_layer = nn.Sequential( - # EqualConvTranspose2d(input_code_dim, in_channel, 4, 1, 0), - # PixelNorm(), - # nn.LeakyReLU(0.1)) - - self.input_layer = nn.Sequential( - EqualConv2d(input_code_dim, input_code_dim, 3, padding=1), - PixelNorm(), - nn.LeakyReLU(0.1), - EqualConv2d(input_code_dim, in_channel, 3, padding=1), - PixelNorm(), - nn.LeakyReLU(0.1) - ) - - self.progression_4 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) - self.progression_8 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) - self.progression_16 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) - self.progression_32 = ConvBlockProGan(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm) - self.progression_64 = ConvBlockProGan(in_channel, in_channel // 2, 3, 1, pixel_norm=pixel_norm) - self.progression_128 = ConvBlockProGan(in_channel // 2, in_channel // 4, 3, 1, pixel_norm=pixel_norm) - self.progression_256 = ConvBlockProGan(in_channel // 4, in_channel // 4, 3, 1, pixel_norm=pixel_norm) - self.progression_512 = ConvBlockProGan(in_channel // 4, in_channel // 8, 3, 1, pixel_norm=pixel_norm) - - self.to_rgb_8 = EqualConv2d(in_channel, out_channels, 1) - self.to_rgb_16 = EqualConv2d(in_channel, out_channels, 1) - self.to_rgb_32 = EqualConv2d(in_channel, out_channels, 1) - self.to_rgb_64 = EqualConv2d(in_channel // 2, out_channels, 1) - self.to_rgb_128 = EqualConv2d(in_channel // 4, out_channels, 1) - self.to_rgb_256 = 
EqualConv2d(in_channel // 4, out_channels, 1) - self.to_rgb_512 = EqualConv2d(in_channel // 8, out_channels, 1) - - def progress(self, feat, module): - out = F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False) - out = module(out) - return out + y = (alpha * straight) + ((1 - alpha) * residual) - def output(self, feat1, feat2, module1, module2, alpha): - if 0 <= alpha < 1: - skip_rgb = upscale(module1(feat1)) - out = (1 - alpha) * skip_rgb + alpha * module2(feat2) + for block in reversed(self.layers[:step - 1]): + y = block(y) else: - out = module2(feat2) - if self.tanh: - return torch.tanh(out) - return out + y = self.rgb_to_features[0](x) - def forward(self, input, step=0, alpha=-1): - if step > self.max_steps: - step = self.max_steps - # out_4 = self.input_layer(input.view(-1, self.input_dim, 1, 1)) - out_4 = self.input_layer(input) - out_4 = self.progression_4(out_4) - out_8 = self.progress(out_4, self.progression_8) - if step == 1: - if self.tanh: - return torch.tanh(self.to_rgb_8(out_8)) - return self.to_rgb_8(out_8) - - out_16 = self.progress(out_8, self.progression_16) - if step == 2: - return self.output(out_8, out_16, self.to_rgb_8, self.to_rgb_16, alpha) - - out_32 = self.progress(out_16, self.progression_32) - if step == 3: - return self.output(out_16, out_32, self.to_rgb_16, self.to_rgb_32, alpha) - - out_64 = self.progress(out_32, self.progression_64) - if step == 4: - return self.output(out_32, out_64, self.to_rgb_32, self.to_rgb_64, alpha) - - out_128 = self.progress(out_64, self.progression_128) - if step == 5: - return self.output(out_64, out_128, self.to_rgb_64, self.to_rgb_128, alpha) - - out_256 = self.progress(out_128, self.progression_256) - if step == 6: - return self.output(out_128, out_256, self.to_rgb_128, self.to_rgb_256, alpha) - - out_512 = self.progress(out_256, self.progression_512) - if step == 7: - return self.output(out_256, out_512, self.to_rgb_256, self.to_rgb_512, alpha) - - -class DiscriminatorProGan(nn.Module): - def __init__(self, feat_dim=128, in_dim=3, max_steps=6): - super().__init__() - self.max_steps = max_steps - self.progression = nn.ModuleList([ConvBlockProGan(feat_dim // 8, feat_dim // 4, 3, 1), - ConvBlockProGan(feat_dim // 4, feat_dim // 4, 3, 1), - ConvBlockProGan(feat_dim // 4, feat_dim // 2, 3, 1), - ConvBlockProGan(feat_dim // 2, feat_dim, 3, 1), - ConvBlockProGan(feat_dim, feat_dim, 3, 1), - ConvBlockProGan(feat_dim, feat_dim, 3, 1), - ConvBlockProGan(feat_dim, feat_dim, 3, 1), - ConvBlockProGan(feat_dim + 1, feat_dim, 3, 1, 4, 0)]) - - self.from_rgb = nn.ModuleList([EqualConv2d(in_dim, feat_dim // 8, 1), - EqualConv2d(in_dim, feat_dim // 4, 1), - EqualConv2d(in_dim, feat_dim // 4, 1), - EqualConv2d(in_dim, feat_dim // 2, 1), - EqualConv2d(in_dim, feat_dim, 1), - EqualConv2d(in_dim, feat_dim, 1), - EqualConv2d(in_dim, feat_dim, 1), - EqualConv2d(in_dim, feat_dim, 1)]) - - self.n_layer = len(self.progression) - - self.linear = EqualLinear(feat_dim, 1) - - def forward(self, input, step=0, alpha=-1): - for i in range(step, -1, -1): - index = self.n_layer - i - 1 - - if i == step: - out = self.from_rgb[index](input) - - if i == 0: - out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-6) - mean_std = out_std.mean() - mean_std = mean_std.expand(out.size(0), 1, out_std.size(1), out_std.size(2)) - out = torch.cat([out, mean_std], 1) - - out = self.progression[index](out) - - if i > 0: - # out = F.avg_pool2d(out, 2) - out = F.interpolate(out, scale_factor=0.5, mode='bilinear', align_corners=False) - - if i == step and 0 
<= alpha < 1: - # skip_rgb = F.avg_pool2d(input, 2) - skip_rgb = F.interpolate(input, scale_factor=0.5, mode='bilinear', align_corners=False) - skip_rgb = self.from_rgb[index + 1](skip_rgb) - out = (1 - alpha) * skip_rgb + alpha * out - - out = out.squeeze(2).squeeze(2) - # print(input.size(), out.size(), step) - out = self.linear(out) + out = self.final_block(y) return out diff --git a/models/progan_layers.py b/models/progan_layers.py new file mode 100644 index 00000000000..1f0ca63c02f --- /dev/null +++ b/models/progan_layers.py @@ -0,0 +1,517 @@ +""" Module containing custom layers +Thanks to https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/ +""" + +import torch as th + + +# extending Conv2D and Deconv2D layers for equalized learning rate logic +from torch.nn import Conv2d + + +class _equalized_conv2d(th.nn.Module): + """ conv2d with the concept of equalized learning rate + Args: + :param c_in: input channels + :param c_out: output channels + :param k_size: kernel size (h, w) should be a tuple or a single integer + :param stride: stride for conv + :param pad: padding + :param bias: whether to use bias or not + """ + + def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True): + """ constructor for the class """ + from torch.nn.modules.utils import _pair + from numpy import sqrt, prod + + super(_equalized_conv2d, self).__init__() + + # define the weight and bias if to be used + self.weight = th.nn.Parameter(th.nn.init.normal_( + th.empty(c_out, c_in, *_pair(k_size)) + )) + + self.use_bias = bias + self.stride = stride + self.pad = pad + + if self.use_bias: + self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0)) + + fan_in = prod(_pair(k_size)) * c_in # value of fan_in + self.scale = sqrt(2) / sqrt(fan_in) + + def forward(self, x): + """ + forward pass of the network + :param x: input + :return: y => output + """ + from torch.nn.functional import conv2d + + return conv2d(input=x, + weight=self.weight * self.scale, # scale the weight on runtime + bias=self.bias if self.use_bias else None, + stride=self.stride, + padding=self.pad) + + def extra_repr(self): + return ", ".join(map(str, self.weight.shape)) + + +class _equalized_deconv2d(th.nn.Module): + """ Transpose convolution using the equalized learning rate + Args: + :param c_in: input channels + :param c_out: output channels + :param k_size: kernel size + :param stride: stride for convolution transpose + :param pad: padding + :param bias: whether to use bias or not + """ + + def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True): + """ constructor for the class """ + from torch.nn.modules.utils import _pair + from numpy import sqrt + + super(_equalized_deconv2d, self).__init__() + + # define the weight and bias if to be used + self.weight = th.nn.Parameter(th.nn.init.normal_( + th.empty(c_in, c_out, *_pair(k_size)) + )) + + self.use_bias = bias + self.stride = stride + self.pad = pad + + if self.use_bias: + self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0)) + + fan_in = c_in # value of fan_in for deconv + self.scale = sqrt(2) / sqrt(fan_in) + + def forward(self, x): + """ + forward pass of the layer + :param x: input + :return: y => output + """ + from torch.nn.functional import conv_transpose2d + + return conv_transpose2d(input=x, + weight=self.weight * self.scale, # scale the weight on runtime + bias=self.bias if self.use_bias else None, + stride=self.stride, + padding=self.pad) + + def extra_repr(self): + return ", ".join(map(str, self.weight.shape)) + + +class 
_equalized_linear(th.nn.Module): + """ Linear layer using equalized learning rate + Args: + :param c_in: number of input channels + :param c_out: number of output channels + :param bias: whether to use bias with the linear layer + """ + + def __init__(self, c_in, c_out, bias=True): + """ + Linear layer modified for equalized learning rate + """ + from numpy import sqrt + + super(_equalized_linear, self).__init__() + + self.weight = th.nn.Parameter(th.nn.init.normal_( + th.empty(c_out, c_in) + )) + + self.use_bias = bias + + if self.use_bias: + self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0)) + + fan_in = c_in + self.scale = sqrt(2) / sqrt(fan_in) + + def forward(self, x): + """ + forward pass of the layer + :param x: input + :return: y => output + """ + from torch.nn.functional import linear + return linear(x, self.weight * self.scale, + self.bias if self.use_bias else None) + + +# ---------------------------------------------------------------------------- +# Pixelwise feature vector normalization. +# reference: https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L120 +# ---------------------------------------------------------------------------- +class PixelwiseNorm(th.nn.Module): + def __init__(self): + super(PixelwiseNorm, self).__init__() + + def forward(self, x, alpha=1e-8): + """ + forward pass of the module + :param x: input activations volume + :param alpha: small number for numerical stability + :return: y => pixel normalized activations + """ + y = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).sqrt() # [N1HW] + y = x / y # normalize the input x volume + return y + + +# ========================================================== +# Layers required for Building The generator and +# discriminator +# ========================================================== +class GenInitialBlock(th.nn.Module): + """ Module implementing the initial block of the input """ + + def __init__(self, in_channels, out_channels, use_eql): + """ + constructor for the inner class + :param in_channels: number of input channels to the block + :param use_eql: whether to use equalized learning rate + """ + from torch.nn import LeakyReLU + + super(GenInitialBlock, self).__init__() + + # EqualConv2d(input_code_dim, input_code_dim, 3, padding=1), + # PixelNorm(), + # nn.LeakyReLU(0.1), + # EqualConv2d(input_code_dim, in_channel, 3, padding=1), + # PixelNorm(), + # nn.LeakyReLU(0.1) + + if use_eql: + self.conv_1 = _equalized_conv2d(in_channels, out_channels, (3, 3), bias=True, pad=1) + self.conv_2 = _equalized_conv2d(out_channels, out_channels, (3, 3), bias=True, pad=1) + else: + self.conv_1 = Conv2d(in_channels, out_channels, (3, 3), padding=1, bias=True) + self.conv_2 = Conv2d(out_channels, out_channels, (3, 3), padding=1, bias=True) + # if use_eql: + # self.conv_1 = _equalized_deconv2d(in_channels, in_channels, (4, 4), bias=True) + # self.conv_2 = _equalized_conv2d(in_channels, in_channels, (3, 3), + # pad=1, bias=True) + # + # else: + # from torch.nn import Conv2d, ConvTranspose2d + # self.conv_1 = ConvTranspose2d(in_channels, in_channels, (4, 4), bias=True) + # self.conv_2 = Conv2d(in_channels, in_channels, (3, 3), padding=1, bias=True) + + # Pixelwise feature vector normalization operation + self.pixNorm = PixelwiseNorm() + + # leaky_relu: + self.lrelu = LeakyReLU(0.2) + + def forward(self, x): + """ + forward pass of the block + :param x: input to the module + :return: y => output + """ + # convert the tensor shape: + # y = th.unsqueeze(th.unsqueeze(x, -1), -1) + # perform 
the forward computations: + y = self.lrelu(self.conv_1(x)) + y = self.lrelu(self.conv_2(y)) + + # apply pixel norm + y = self.pixNorm(y) + + return y + + +class GenGeneralConvBlock(th.nn.Module): + """ Module implementing a general convolutional block """ + + def __init__(self, in_channels, out_channels, use_eql): + """ + constructor for the class + :param in_channels: number of input channels to the block + :param out_channels: number of output channels required + :param use_eql: whether to use equalized learning rate + """ + from torch.nn import LeakyReLU + from torch.nn.functional import interpolate + + super(GenGeneralConvBlock, self).__init__() + + self.upsample = lambda x: interpolate(x, scale_factor=2) + + if use_eql: + self.conv_1 = _equalized_conv2d(in_channels, out_channels, (3, 3), + pad=1, bias=True) + self.conv_2 = _equalized_conv2d(out_channels, out_channels, (3, 3), + pad=1, bias=True) + else: + from torch.nn import Conv2d + self.conv_1 = Conv2d(in_channels, out_channels, (3, 3), + padding=1, bias=True) + self.conv_2 = Conv2d(out_channels, out_channels, (3, 3), + padding=1, bias=True) + + # Pixelwise feature vector normalization operation + self.pixNorm = PixelwiseNorm() + + # leaky_relu: + self.lrelu = LeakyReLU(0.2) + + def forward(self, x): + """ + forward pass of the block + :param x: input + :return: y => output + """ + y = self.upsample(x) + y = self.pixNorm(self.lrelu(self.conv_1(y))) + y = self.pixNorm(self.lrelu(self.conv_2(y))) + + return y + + +# function to calculate the Exponential moving averages for the Generator weights +# This function updates the exponential average weights based on the current training +def update_average(model_tgt, model_src, beta): + """ + update the model_target using exponential moving averages + :param model_tgt: target model + :param model_src: source model + :param beta: value of decay beta + :return: None (updates the target model) + """ + + # utility function for toggling the gradient requirements of the models + def toggle_grad(model, requires_grad): + for p in model.parameters(): + p.requires_grad_(requires_grad) + + # turn off gradient calculation + toggle_grad(model_tgt, False) + toggle_grad(model_src, False) + + param_dict_src = dict(model_src.named_parameters()) + + for p_name, p_tgt in model_tgt.named_parameters(): + p_src = param_dict_src[p_name] + assert (p_src is not p_tgt) + p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src) + + # turn back on the gradient calculation + toggle_grad(model_tgt, True) + toggle_grad(model_src, True) + + +class MinibatchStdDev(th.nn.Module): + """ + Minibatch standard deviation layer for the discriminator + """ + + def __init__(self): + """ + derived class constructor + """ + super(MinibatchStdDev, self).__init__() + + def forward(self, x, alpha=1e-8): + """ + forward pass of the layer + :param x: input activation volume + :param alpha: small number for numerical stability + :return: y => x appended with standard deviation constant map + """ + batch_size, _, height, width = x.shape + + # [B x C x H x W] Subtract mean over batch. + y = x - x.mean(dim=0, keepdim=True) + + # [1 x C x H x W] Calc standard deviation over batch + y = th.sqrt(y.pow(2.).mean(dim=0, keepdim=False) + alpha) + + # [1] Take average over feature_maps and pixels. + y = y.mean().view(1, 1, 1, 1) + + # [B x 1 x H x W] Replicate over group and pixels. + y = y.repeat(batch_size, 1, height, width) + + # [B x C x H x W] Append as new feature_map. 
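+        # (the appended constant channel exposes batch-wide variation to the
+        # discriminator, which penalizes mode collapse in the generator)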
+ y = th.cat([x, y], 1) + + # return the computed values: + return y + + +class DisFinalBlock(th.nn.Module): + """ Final block for the Discriminator """ + + def __init__(self, in_channels, use_eql): + """ + constructor of the class + :param in_channels: number of input channels + :param use_eql: whether to use equalized learning rate + """ + from torch.nn import LeakyReLU + + super(DisFinalBlock, self).__init__() + + # declare the required modules for forward pass + self.batch_discriminator = MinibatchStdDev() + if use_eql: + self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1, bias=True) + self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4), bias=True) + # final conv layer emulates a fully connected layer + self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True) + else: + from torch.nn import Conv2d + self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True) + self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True) + # final conv layer emulates a fully connected layer + self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True) + + # leaky_relu: + self.lrelu = LeakyReLU(0.2) + + def forward(self, x): + """ + forward pass of the FinalBlock + :param x: input + :return: y => output + """ + # minibatch_std_dev layer + y = self.batch_discriminator(x) + + # define the computations + y = self.lrelu(self.conv_1(y)) + y = self.lrelu(self.conv_2(y)) + + # fully connected layer + y = self.conv_3(y) # This layer has linear activation + + # flatten the output raw discriminator scores + return y.view(-1) + + +class ConDisFinalBlock(th.nn.Module): + """ Final block for the Conditional Discriminator + Uses the Projection mechanism from the paper -> https://arxiv.org/pdf/1802.05637.pdf + """ + + def __init__(self, in_channels, num_classes, use_eql): + """ + constructor of the class + :param in_channels: number of input channels + :param num_classes: number of classes for conditional discrimination + :param use_eql: whether to use equalized learning rate + """ + from torch.nn import LeakyReLU, Embedding + + super(ConDisFinalBlock, self).__init__() + + # declare the required modules for forward pass + self.batch_discriminator = MinibatchStdDev() + + if use_eql: + self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1, bias=True) + self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4), bias=True) + + # final conv layer emulates a fully connected layer + self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True) + else: + from torch.nn import Conv2d + self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True) + self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True) + + # final conv layer emulates a fully connected layer + self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True) + + # we also need an embedding matrix for the label vectors + self.label_embedder = Embedding(num_classes, in_channels, max_norm=1) + + # leaky_relu: + self.lrelu = LeakyReLU(0.2) + + def forward(self, x, labels): + """ + forward pass of the FinalBlock + :param x: input + :param labels: samples' labels for conditional discrimination + Note that these are pure integer labels [Batch_size x 1] + :return: y => output + """ + # minibatch_std_dev layer + y = self.batch_discriminator(x) # [B x C x 4 x 4] + + # perform the forward pass + y = self.lrelu(self.conv_1(y)) # [B x C x 4 x 4] + + # obtain the computed features + y = self.lrelu(self.conv_2(y)) # [B x C x 1 x 1] + + # embed the labels + 
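+        # (projection conditioning, per the paper linked above: the label
+        # embedding is dotted with the feature vector and added to the raw score)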
labels = self.label_embedder(labels) # [B x C] + + # compute the inner product with the label embeddings + y_ = th.squeeze(th.squeeze(y, dim=-1), dim=-1) # [B x C] + projection_scores = (y_ * labels).sum(dim=-1) # [B] + + # normal discrimination score + y = self.lrelu(self.conv_3(y)) # This layer has linear activation + + # calculate the total score + final_score = y.view(-1) + projection_scores + + # return the output raw discriminator scores + return final_score + + +class DisGeneralConvBlock(th.nn.Module): + """ General block in the discriminator """ + + def __init__(self, in_channels, out_channels, use_eql): + """ + constructor of the class + :param in_channels: number of input channels + :param out_channels: number of output channels + :param use_eql: whether to use equalized learning rate + """ + from torch.nn import AvgPool2d, LeakyReLU + + super(DisGeneralConvBlock, self).__init__() + + if use_eql: + self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3), pad=1, bias=True) + self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3), pad=1, bias=True) + else: + from torch.nn import Conv2d + self.conv_1 = Conv2d(in_channels, in_channels, (3, 3), padding=1, bias=True) + self.conv_2 = Conv2d(in_channels, out_channels, (3, 3), padding=1, bias=True) + + self.downSampler = AvgPool2d(2) + + # leaky_relu: + self.lrelu = LeakyReLU(0.2) + + def forward(self, x): + """ + forward pass of the module + :param x: input + :return: y => output + """ + # define the computations + y = self.lrelu(self.conv_1(x)) + y = self.lrelu(self.conv_2(y)) + y = self.downSampler(y) + + return y diff --git a/models/progan_model.py b/models/progan_model.py index 5b9661edfdf..7b44b8fb3c2 100644 --- a/models/progan_model.py +++ b/models/progan_model.py @@ -3,10 +3,9 @@ from .base_model import BaseModel from . 
import networks -from torch.nn import functional as F import numpy as np -from .networks import cal_gradient_penalty +from .progan_layers import update_average try: from apex import amp @@ -33,7 +32,7 @@ def modify_commandline_options(parser, is_train=True): parser.set_defaults(netG='progan', netD='progan', dataset_mode='single', beta1=0., ngf=512, ndf=512, gan_mode='wgangp') - parser.add_argument('--z_dim', type=int, default=32, help='random noise dim') + parser.add_argument('--z_dim', type=int, default=128, help='random noise dim') parser.add_argument('--max_steps', type=int, default=6, help='steps of growing') parser.add_argument('--steps_schedule', type=str, default='linear', help='type of when to turn to the next step: linear or fibonacci') @@ -103,7 +102,7 @@ def __init__(self, opt): """ total epochs """ - self.step = 1 + self.step = 0 """ current step of network, 1-6 """ @@ -138,7 +137,8 @@ def set_input(self, input): """ AtoB = self.opt.direction == 'AtoB' self.real_B = input['A' if AtoB else 'B'].to(self.device) - self.real_B = F.interpolate(self.real_B, size=(4 * 2 ** self.step, 4 * 2 ** self.step), mode='bilinear') + self.real_B = self.__progressive_downsampling(self.real_B, self.step, self.alpha) + #self.real_B = F.interpolate(self.real_B, size=(4 * 2 ** self.step, 4 * 2 ** self.step), mode='bilinear') self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): @@ -150,6 +150,7 @@ def forward(self): z = torch.randn((batch_size, self.z_dim, self.opt.crop_size // (2 ** self.max_steps), self.opt.crop_size // (2 ** self.max_steps)), device=self.device) + # z = torch.randn(batch_size, 512).to(self.device) self.fake_B = net(z, step=step, alpha=alpha) def backward_D(self): @@ -217,9 +218,9 @@ def optimize_parameters(self): def create_epochs_schedule(self, steps_schedule_type): if steps_schedule_type == 'fibonacci': - basic_weights = np.array([0., 1., 2., 3., 5., 8., 13., 21., 34.]) + basic_weights = np.array([3., 3., 3., 5., 8., 13., 21., 34., 55.]) else: - basic_weights = np.array([0., 1, 1, 1, 1, 1, 1, 1, 1]) + basic_weights = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1]) basic_weights = basic_weights[:self.max_steps + 1] epochs_schedule = self.total_steps * basic_weights / np.sum(basic_weights) epochs_schedule = epochs_schedule.astype(np.int) @@ -233,7 +234,7 @@ def update_inners_counters(self): self.iter += 1 self.alpha = min(1, (2. / (self.epochs_schedule[self.step])) * self.iter) if self.iter > self.epochs_schedule[self.step]: - print('turn to step % s' % str(self.step + 1)) + print('turn to step %s' % str(self.step + 1)) self.alpha = 0 self.iter = 0 self.step += 1 @@ -242,17 +243,46 @@ def update_inners_counters(self): self.alpha = 1 self.step = self.max_steps + print('new alpha: %s, new step: %s' % (self.alpha, self.step)) + def accumulate(self, decay=0.999): """ Accumulate weights from self.C to self.G with decay @param decay decay """ - par1 = dict(self.netG.named_parameters()) - par2 = dict(self.netC.named_parameters()) - - for k in par1.keys(): - par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) + update_average(self.netG, self.netC, decay) def update_learning_rate(self): super(ProGanModel, self).update_learning_rate() self.update_inners_counters() + + def __progressive_downsampling(self, real_batch, depth, alpha): + """ + private helper for downsampling the original images in order to facilitate the + progressive growing of the layers. 
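+        (for example, assuming max_steps=6 and 256x256 crops: at depth 4 the
+        batch is average-pooled by 2 ** (6 - 4) = 4 down to 64x64 and faded,
+        according to alpha, against a 2x-upsampled 32x32 copy)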
+ :param real_batch: batch of real samples + :param depth: depth at which training is going on + :param alpha: current value of the fader alpha + :return: real_samples => modified real batch of samples + """ + + from torch.nn import AvgPool2d + from torch.nn.functional import interpolate + + # downsample the real_batch for the given depth + down_sample_factor = int(np.power(2, self.max_steps - depth)) + prior_downsample_factor = max(int(np.power(2, self.max_steps - depth + 1)), 0) + + ds_real_samples = AvgPool2d(down_sample_factor)(real_batch) + + if depth > 0: + prior_ds_real_samples = interpolate(AvgPool2d(prior_downsample_factor)(real_batch), + scale_factor=2) + else: + prior_ds_real_samples = ds_real_samples + + # real samples are a combination of ds_real_samples and prior_ds_real_samples + real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples) + + # return the so computed real_samples + return real_samples \ No newline at end of file From c15b89b0e37a1ca7be1278e30a704a5103f1b65e Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Tue, 28 Jul 2020 15:19:08 +0300 Subject: [PATCH 18/19] Stability improvements --- models/progan_model.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/models/progan_model.py b/models/progan_model.py index 7b44b8fb3c2..11a84f5a2ed 100644 --- a/models/progan_model.py +++ b/models/progan_model.py @@ -86,8 +86,8 @@ def __init__(self, opt): self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) self.criterionL1 = torch.nn.L1Loss() # initialize optimizers; schedulers will be automatically created by function . - self.optimizer_C = torch.optim.Adam(self.netC.parameters(), lr=opt.lr, betas=(opt.beta1, 0.99)) - self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.99)) + self.optimizer_C = torch.optim.Adam(self.netC.parameters(), lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-6) + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-6) self.optimizers.append(self.optimizer_C) self.optimizers.append(self.optimizer_D) @@ -182,11 +182,12 @@ def backward_D(self): self.loss_D_gradpen = 10 * grad_penalty self.loss_D += self.loss_D_gradpen - if self.opt.apex: - with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled: - loss_D_scaled.backward() - else: - self.loss_D.backward() + if not (torch.isinf(self.loss_D) or torch.isnan(self.loss_D) or torch.mean(torch.abs(self.loss_D)) > 100): + if self.opt.apex: + with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled: + loss_D_scaled.backward() + else: + self.loss_D.backward() def backward_G(self): """Calculate GAN loss for the generator""" @@ -196,11 +197,12 @@ def backward_G(self): self.loss_G_GAN = self.criterionGAN(pred_fake, True) self.loss_G = self.loss_G_GAN - if self.opt.apex: - with amp.scale_loss(self.loss_G, self.optimizer_C, loss_id=1) as loss_G_scaled: - loss_G_scaled.backward() - else: - self.loss_G.backward() + if not (torch.isinf(self.loss_G) or torch.isnan(self.loss_G) or torch.mean(torch.abs(self.loss_G)) > 100): + if self.opt.apex: + with amp.scale_loss(self.loss_G, self.optimizer_C, loss_id=1) as loss_G_scaled: + loss_G_scaled.backward() + else: + self.loss_G.backward() def optimize_parameters(self): self.forward() # compute fake images: G(A) From ba0beb2ac48577415eeadfa47a5f8028a0450cc6 Mon Sep 17 00:00:00 2001 From: seovchinnikov Date: Tue, 28 Jul 2020 15:22:58 +0300 Subject: [PATCH 19/19] Stability 
improvements --- models/progan_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/progan_layers.py b/models/progan_layers.py index 1f0ca63c02f..100810935f1 100644 --- a/models/progan_layers.py +++ b/models/progan_layers.py @@ -158,7 +158,7 @@ class PixelwiseNorm(th.nn.Module): def __init__(self): super(PixelwiseNorm, self).__init__() - def forward(self, x, alpha=1e-8): + def forward(self, x, alpha=1e-6): """ forward pass of the module :param x: input activations volume @@ -325,7 +325,7 @@ def __init__(self): """ super(MinibatchStdDev, self).__init__() - def forward(self, x, alpha=1e-8): + def forward(self, x, alpha=1e-6): """ forward pass of the layer :param x: input activation volume