diff --git a/models/base_model.py b/models/base_model.py
index 9cfb761897e..a387ec27173 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -227,3 +227,13 @@ def set_requires_grad(self, nets, requires_grad=False):
             if net is not None:
                 for param in net.parameters():
                     param.requires_grad = requires_grad
+
+    def make_data_parallel(self):
+        """Wrap every registered network in torch.nn.DataParallel (a no-op on CPU)"""
+        if len(self.gpu_ids) == 0:
+            return
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                net = torch.nn.DataParallel(net, self.gpu_ids)  # multi-GPUs
+                setattr(self, 'net' + name, net)
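Apex patches the unwrapped modules and optimizers, which is why the models below call make_data_parallel only after amp.initialize. A minimal sketch of the intended call order (assuming Apex is installed and `model` is a hypothetical BaseModel subclass with gpu_ids set):

    model.netG, model.optimizer_G = amp.initialize(model.netG, model.optimizer_G, opt_level='O1')
    model.make_data_parallel()  # only now wrap each registered net in torch.nn.DataParallel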
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 15bb72d8ddc..63b980d915d 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -3,7 +3,12 @@
 from util.image_pool import ImagePool
 from .base_model import BaseModel
 from . import networks
+from torch.utils.checkpoint import checkpoint
+try:
+    from apex import amp
+except ImportError:
+    print("Please install NVIDIA Apex for safe mixed precision if you want to use a non-default --opt_level")
 
 
 class CycleGANModel(BaseModel):
     """
@@ -96,6 +101,13 @@ def __init__(self, opt):
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
 
+        if opt.apex:
+            [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D] = amp.initialize(
+                [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=3)
+
+        # the networks need to be wrapped after amp.initialize
+        self.make_data_parallel()
+
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
         Parameters:
@@ -112,11 +124,17 @@ def set_input(self, input):
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
         self.fake_B = self.netG_A(self.real_A)  # G_A(A)
-        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
+        if not self.isTrain or not self.opt.checkpointing:
+            self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
+        else:
+            self.rec_A = checkpoint(self.netG_B, self.fake_B)
         self.fake_A = self.netG_B(self.real_B)  # G_B(B)
-        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
+        if not self.isTrain or not self.opt.checkpointing:
+            self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
+        else:
+            self.rec_B = checkpoint(self.netG_A, self.fake_A)
 
-    def backward_D_basic(self, netD, real, fake):
+    def backward_D_basic(self, netD, real, fake, loss_id):
         """Calculate GAN loss for the discriminator
 
         Parameters:
@@ -135,18 +153,23 @@
         loss_D_fake = self.criterionGAN(pred_fake, False)
         # Combined loss and calculate gradients
         loss_D = (loss_D_real + loss_D_fake) * 0.5
-        loss_D.backward()
+        if self.opt.apex:
+            with amp.scale_loss(loss_D, self.optimizer_D, loss_id=loss_id) as loss_D_scaled:
+                loss_D_scaled.backward()
+        else:
+            loss_D.backward()
+
         return loss_D
 
     def backward_D_A(self):
         """Calculate GAN loss for discriminator D_A"""
         fake_B = self.fake_B_pool.query(self.fake_B)
-        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, loss_id=0)
 
     def backward_D_B(self):
         """Calculate GAN loss for discriminator D_B"""
         fake_A = self.fake_A_pool.query(self.fake_A)
-        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A, loss_id=1)
 
     def backward_G(self):
         """Calculate the loss for generators G_A and G_B"""
@@ -175,7 +198,13 @@ def backward_G(self):
         self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
         # combined loss and calculate gradients
         self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
-        self.loss_G.backward()
+
+        if self.opt.apex:
+            with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=2) as loss_G_scaled:
+                loss_G_scaled.backward()
+        else:
+            self.loss_G.backward()
+
 
     def optimize_parameters(self):
         """Calculate losses, gradients, and update network weights; called in every training iteration"""
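The --checkpointing path above trades compute for memory: torch.utils.checkpoint drops the intermediate activations of netG_B/netG_A during the forward pass and recomputes them during backward. A self-contained sketch of the semantics (any nn.Module works):

    import torch
    from torch.utils.checkpoint import checkpoint

    net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU())
    x = torch.randn(1, 3, 32, 32, requires_grad=True)
    y = checkpoint(net, x)  # activations inside `net` are not stored here ...
    y.sum().backward()      # ... they are recomputed during this backward pass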
diff --git a/models/networks.py b/models/networks.py
index b3a10c99c20..2c9989819f3 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -2,8 +2,10 @@
 import torch.nn as nn
 from torch.nn import init
 import functools
-from torch.optim import lr_scheduler
+from torch.nn.functional import interpolate
+from torch.optim import lr_scheduler
+import numpy as np
 
 
 ###############################################################################
 # Helper Functions
@@ -98,7 +100,7 @@ def init_func(m):  # define the initialization function
     net.apply(init_func)  # apply the initialization function
 
 
-def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
+def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_weights_=True):
     """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
     Parameters:
         net (network)      -- the network to be initialized
@@ -111,12 +113,13 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
     if len(gpu_ids) > 0:
         assert(torch.cuda.is_available())
         net.to(gpu_ids[0])
-        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
-    init_weights(net, init_type, init_gain=init_gain)
+    if init_weights_:
+        init_weights(net, init_type, init_gain=init_gain)
     return net
 
 
-def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
+def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[],
+             init_weights=True, **kwargs):
     """Create a generator
 
     Parameters:
@@ -154,12 +157,16 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
         net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
     elif netG == 'unet_256':
         net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+    elif netG == 'progan':
+        net = GeneratorProGanV2(input_code_dim=input_nc, in_channel=ngf, max_steps=kwargs['max_steps'] + 1, out_channels=output_nc)
+        # net = GeneratorProGan(input_code_dim=input_nc, in_channel=ngf, max_steps=kwargs['max_steps'], out_channels=output_nc)
     else:
         raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
-    return init_net(net, init_type, init_gain, gpu_ids)
+    return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights)
 
 
-def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
+def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[],
+             init_weights=True, **kwargs):
     """Create a discriminator
 
     Parameters:
@@ -198,9 +205,12 @@
         net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
     elif netD == 'pixel':     # classify if each pixel is real or fake
         net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+    elif netD == 'progan':
+        net = DiscriminatorProGanV2(feat_dim=ndf, in_dim=input_nc, max_steps=kwargs['max_steps'] + 1)
+        # net = DiscriminatorProGan(feat_dim=ndf, in_dim=input_nc, max_steps=kwargs['max_steps'])
     else:
         raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
-    return init_net(net, init_type, init_gain, gpu_ids)
+    return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights)
 
 
 ##############################################################################
@@ -613,3 +623,195 @@ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
     def forward(self, input):
         """Standard forward."""
         return self.net(input)
+
+
+# ========================================================================================
+# Generator Module of ProGAN
+# can be used with ProGAN or standalone (for inference)
+# Thanks to https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/
+# ========================================================================================
+
+
+class GeneratorProGanV2(nn.Module):
+    """ Generator of the GAN network """
+
+    def __init__(self, max_steps=7, input_code_dim=512, in_channel=512, out_channels=3, use_eql=True):
+        """
+        constructor for the Generator class
+        :param max_steps: required depth of the network
+        :param input_code_dim: size of the latent manifold
+        :param in_channel: base number of convolution channels
+        :param out_channels: number of output (image) channels
+        :param use_eql: whether to use equalized learning rate
+        """
+        from .progan_layers import GenGeneralConvBlock, GenInitialBlock, _equalized_conv2d
+
+        super(GeneratorProGanV2, self).__init__()
+
+        assert input_code_dim != 0 and ((input_code_dim & (input_code_dim - 1)) == 0), \
+            "latent size not a power of 2"
+        if max_steps >= 4:
+            assert in_channel >= np.power(2, max_steps - 4), "in_channel size will diminish to zero"
+
+        # state of the generator:
+        self.use_eql = use_eql
+        self.depth = max_steps
+        self.latent_size = input_code_dim
+        self.channels_conv = in_channel
+
+        # register the modules required for the GAN
+        self.initial_block = GenInitialBlock(in_channels=self.latent_size, out_channels=in_channel, use_eql=self.use_eql)
+
+        # create a module list of the other required general convolution blocks
+        self.layers = nn.ModuleList([])  # initialize to empty list
+
+        # create the ToRGB layers for various outputs:
+        if self.use_eql:
+            self.toRGB = lambda in_channels: \
+                _equalized_conv2d(in_channels, out_channels, (1, 1), bias=True)
+        else:
+            from torch.nn import Conv2d
+            self.toRGB = lambda in_channels: Conv2d(in_channels, out_channels, (1, 1), bias=True)
+
+        self.rgb_converters = nn.ModuleList([self.toRGB(self.channels_conv)])
+
+        # create the remaining layers
+        for i in range(self.depth - 1):
+            if i <= 2:
+                layer = GenGeneralConvBlock(self.channels_conv,
+                                            self.channels_conv, use_eql=self.use_eql)
+                rgb = self.toRGB(self.channels_conv)
+            else:
+                layer = GenGeneralConvBlock(
+                    int(self.channels_conv // np.power(2, i - 3)),
+                    int(self.channels_conv // np.power(2, i - 2)),
+                    use_eql=self.use_eql
+                )
+                rgb = self.toRGB(int(self.channels_conv // np.power(2, i - 2)))
+            self.layers.append(layer)
+            self.rgb_converters.append(rgb)
+
+        # register the temporary upsampler
+        self.temporaryUpsampler = lambda x: interpolate(x, scale_factor=2)
+
+    def forward(self, x, step, alpha):
+        """
+        forward pass of the Generator
+        :param x: input noise
+        :param step: current depth from where output is required
+        :param alpha: value of alpha for fade-in effect
+        :return: y => output
+        """
+        # step = step - 1
+        assert step < self.depth, "Requested output depth cannot be produced"
+
+        y = self.initial_block(x)
+
+        if step > 0:
+            for block in self.layers[:step - 1]:
+                y = block(y)
+
+            residual = self.rgb_converters[step - 1](self.temporaryUpsampler(y))
+            straight = self.rgb_converters[step](self.layers[step - 1](y))
+
+            out = (alpha * straight) + ((1 - alpha) * residual)
+
+        else:
+            out = self.rgb_converters[0](y)
+
+        return out
+
+
+# ========================================================================================
+# Discriminator Module of ProGAN
+# can be used with ProGAN or standalone (for inference).
+# Thanks to https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/
+# ========================================================================================
+
+
+class DiscriminatorProGanV2(nn.Module):
+    """ Discriminator of the GAN """
+
+    def __init__(self, max_steps=7, feat_dim=512, in_dim=3, use_eql=True):
+        """
+        constructor for the class
+        :param max_steps: total height of the discriminator (must be equal to the Generator depth)
+        :param feat_dim: size of the deepest features extracted
+                         (must be equal to the Generator latent_size)
+        :param in_dim: number of input (image) channels
+        :param use_eql: whether to use equalized learning rate
+        """
+        from torch.nn import ModuleList, AvgPool2d
+        from .progan_layers import DisGeneralConvBlock, DisFinalBlock, _equalized_conv2d
+
+        super(DiscriminatorProGanV2, self).__init__()
+
+        assert feat_dim != 0 and ((feat_dim & (feat_dim - 1)) == 0), \
+            "latent size not a power of 2"
+        if max_steps >= 4:
+            assert feat_dim >= np.power(2, max_steps - 4), "feature size cannot be produced"
+
+        # create state of the object
+        self.use_eql = use_eql
+        self.height = max_steps
+        self.feature_size = feat_dim
+
+        self.final_block = DisFinalBlock(self.feature_size, use_eql=self.use_eql)
+
+        # create a module list of the other required general convolution blocks
+        self.layers = ModuleList([])  # initialize to empty list
+
+        # create the fromRGB layers for various inputs (note: in_dim, not a hard-coded 3):
+        if self.use_eql:
+            self.fromRGB = lambda out_channels: \
+                _equalized_conv2d(in_dim, out_channels, (1, 1), bias=True)
+        else:
+            from torch.nn import Conv2d
+            self.fromRGB = lambda out_channels: Conv2d(in_dim, out_channels, (1, 1), bias=True)
+
+        self.rgb_to_features = ModuleList([self.fromRGB(self.feature_size)])
+
+        # create the remaining layers
+        for i in range(self.height - 1):
+            if i > 2:
+                layer = DisGeneralConvBlock(
+                    int(self.feature_size // np.power(2, i - 2)),
+                    int(self.feature_size // np.power(2, i - 3)),
+                    use_eql=self.use_eql
+                )
+                rgb = self.fromRGB(int(self.feature_size // np.power(2, i - 2)))
+            else:
+                layer = DisGeneralConvBlock(self.feature_size,
+                                            self.feature_size, use_eql=self.use_eql)
+                rgb = self.fromRGB(self.feature_size)
+
+            self.layers.append(layer)
+            self.rgb_to_features.append(rgb)
+
+        # register the temporary downSampler
+        self.temporaryDownsampler = AvgPool2d(2)
+
+    def forward(self, x, step, alpha):
+        """
+        forward pass of the discriminator
+        :param x: input to the network
+        :param step: current height of operation (Progressive GAN)
+        :param alpha: current value of alpha for fade-in
+        :return: out => raw prediction values (WGAN-GP)
+        """
+        # step = step - 1
+        assert step < self.height, "Requested output depth cannot be produced"
+
+        if step > 0:
+            residual = self.rgb_to_features[step - 1](self.temporaryDownsampler(x))
+
+            straight = self.layers[step - 1](
+                self.rgb_to_features[step](x)
+            )
+
+            y = (alpha * straight) + ((1 - alpha) * residual)
+
+            for block in reversed(self.layers[:step - 1]):
+                y = block(y)
+        else:
+            y = self.rgb_to_features[0](x)
+
+        out = self.final_block(y)
+
+        return out
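GeneratorProGanV2 grows the output resolution with `step` and blends the newest block in via `alpha`, following the ProGAN fade-in rule out = alpha * straight + (1 - alpha) * residual. A hypothetical smoke test (shapes follow from the code above: fed a 4x4 latent, the output at step s is 4 * 2**s pixels):

    import torch
    from models.networks import GeneratorProGanV2

    G = GeneratorProGanV2(max_steps=7, input_code_dim=128, in_channel=512, out_channels=3)
    z = torch.randn(2, 128, 4, 4)
    img = G(z, step=3, alpha=0.5)  # -> torch.Size([2, 3, 32, 32])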
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 939eb887ee3..70b18a4de8c 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -1,7 +1,14 @@
 import torch
+from torch.utils.checkpoint import checkpoint
+
 from .base_model import BaseModel
 from . import networks
 
+try:
+    from apex import amp
+except ImportError:
+    print("Please install NVIDIA Apex for safe mixed precision if you want to use a non-default --opt_level")
+
 
 class Pix2PixModel(BaseModel):
     """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
@@ -70,6 +77,12 @@ def __init__(self, opt):
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
 
+        if opt.apex:
+            [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = amp.initialize(
+                [self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=2)
+
+        self.make_data_parallel()
+
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
 
@@ -99,7 +112,12 @@ def backward_D(self):
         self.loss_D_real = self.criterionGAN(pred_real, True)
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
-        self.loss_D.backward()
+
+        if self.opt.apex:
+            with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled:
+                loss_D_scaled.backward()
+        else:
+            self.loss_D.backward()
 
     def backward_G(self):
         """Calculate GAN and L1 loss for the generator"""
@@ -111,7 +129,12 @@ def backward_G(self):
         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN + self.loss_G_L1
-        self.loss_G.backward()
+
+        if self.opt.apex:
+            with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=1) as loss_G_scaled:
+                loss_G_scaled.backward()
+        else:
+            self.loss_G.backward()
 
     def optimize_parameters(self):
         self.forward()                   # compute fake images: G(A)
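Every backward above goes through the same Apex pattern: with num_losses declared at amp.initialize time, each loss keeps its own loss scale, selected by loss_id. The pattern in isolation (a sketch; assumes Apex is installed and `loss`/`optimizer` were passed through amp.initialize):

    with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
        scaled_loss.backward()  # gradients are unscaled again before optimizer.step()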
diff --git a/models/progan_layers.py b/models/progan_layers.py
new file mode 100644
index 00000000000..100810935f1
--- /dev/null
+++ b/models/progan_layers.py
@@ -0,0 +1,517 @@
+""" Module containing custom layers
+Thanks to https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/
+"""
+
+import torch as th
+
+
+# extending Conv2D and Deconv2D layers for equalized learning rate logic
+from torch.nn import Conv2d
+
+
+class _equalized_conv2d(th.nn.Module):
+    """ conv2d with the concept of equalized learning rate
+        Args:
+            :param c_in: input channels
+            :param c_out: output channels
+            :param k_size: kernel size (h, w) should be a tuple or a single integer
+            :param stride: stride for conv
+            :param pad: padding
+            :param bias: whether to use bias or not
+    """
+
+    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
+        """ constructor for the class """
+        from torch.nn.modules.utils import _pair
+        from numpy import sqrt, prod
+
+        super(_equalized_conv2d, self).__init__()
+
+        # define the weight and bias if to be used
+        self.weight = th.nn.Parameter(th.nn.init.normal_(
+            th.empty(c_out, c_in, *_pair(k_size))
+        ))
+
+        self.use_bias = bias
+        self.stride = stride
+        self.pad = pad
+
+        if self.use_bias:
+            self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
+
+        fan_in = prod(_pair(k_size)) * c_in  # value of fan_in
+        self.scale = sqrt(2) / sqrt(fan_in)
+
+    def forward(self, x):
+        """
+        forward pass of the network
+        :param x: input
+        :return: y => output
+        """
+        from torch.nn.functional import conv2d
+
+        return conv2d(input=x,
+                      weight=self.weight * self.scale,  # scale the weight on runtime
+                      bias=self.bias if self.use_bias else None,
+                      stride=self.stride,
+                      padding=self.pad)
+
+    def extra_repr(self):
+        return ", ".join(map(str, self.weight.shape))
+
+
+class _equalized_deconv2d(th.nn.Module):
+    """ Transpose convolution using the equalized learning rate
+        Args:
+            :param c_in: input channels
+            :param c_out: output channels
+            :param k_size: kernel size
+            :param stride: stride for convolution transpose
+            :param pad: padding
+            :param bias: whether to use bias or not
+    """
+
+    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
+        """ constructor for the class """
+        from torch.nn.modules.utils import _pair
+        from numpy import sqrt
+
+        super(_equalized_deconv2d, self).__init__()
+
+        # define the weight and bias if to be used
+        self.weight = th.nn.Parameter(th.nn.init.normal_(
+            th.empty(c_in, c_out, *_pair(k_size))
+        ))
+
+        self.use_bias = bias
+        self.stride = stride
+        self.pad = pad
+
+        if self.use_bias:
+            self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
+
+        fan_in = c_in  # value of fan_in for deconv
+        self.scale = sqrt(2) / sqrt(fan_in)
+
+    def forward(self, x):
+        """
+        forward pass of the layer
+        :param x: input
+        :return: y => output
+        """
+        from torch.nn.functional import conv_transpose2d
+
+        return conv_transpose2d(input=x,
+                                weight=self.weight * self.scale,  # scale the weight on runtime
+                                bias=self.bias if self.use_bias else None,
+                                stride=self.stride,
+                                padding=self.pad)
+
+    def extra_repr(self):
+        return ", ".join(map(str, self.weight.shape))
+
+
+class _equalized_linear(th.nn.Module):
+    """ Linear layer using equalized learning rate
+        Args:
+            :param c_in: number of input channels
+            :param c_out: number of output channels
+            :param bias: whether to use bias with the linear layer
+    """
+
+    def __init__(self, c_in, c_out, bias=True):
+        """
+        Linear layer modified for equalized learning rate
+        """
+        from numpy import sqrt
+
+        super(_equalized_linear, self).__init__()
+
+        self.weight = th.nn.Parameter(th.nn.init.normal_(
+            th.empty(c_out, c_in)
+        ))
+
+        self.use_bias = bias
+
+        if self.use_bias:
+            self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
+
+        fan_in = c_in
+        self.scale = sqrt(2) / sqrt(fan_in)
+
+    def forward(self, x):
+        """
+        forward pass of the layer
+        :param x: input
+        :return: y => output
+        """
+        from torch.nn.functional import linear
+        return linear(x, self.weight * self.scale,
+                      self.bias if self.use_bias else None)
+
+
+# ----------------------------------------------------------------------------
+# Pixelwise feature vector normalization.
+# reference: https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L120
+# ----------------------------------------------------------------------------
+class PixelwiseNorm(th.nn.Module):
+    def __init__(self):
+        super(PixelwiseNorm, self).__init__()
+
+    def forward(self, x, alpha=1e-6):
+        """
+        forward pass of the module
+        :param x: input activations volume
+        :param alpha: small number for numerical stability
+        :return: y => pixel normalized activations
+        """
+        y = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).sqrt()  # [N1HW]
+        y = x / y  # normalize the input x volume
+        return y
+
+
+# ==========================================================
+# Layers required for building the generator and
+# discriminator
+# ==========================================================
+class GenInitialBlock(th.nn.Module):
+    """ Module implementing the initial block of the input """
+
+    def __init__(self, in_channels, out_channels, use_eql):
+        """
+        constructor for the inner class
+        :param in_channels: number of input channels to the block
+        :param out_channels: number of output channels of the block
+        :param use_eql: whether to use equalized learning rate
+        """
+        from torch.nn import LeakyReLU
+
+        super(GenInitialBlock, self).__init__()
+
+        # EqualConv2d(input_code_dim, input_code_dim, 3, padding=1),
+        # PixelNorm(),
+        # nn.LeakyReLU(0.1),
+        # EqualConv2d(input_code_dim, in_channel, 3, padding=1),
+        # PixelNorm(),
+        # nn.LeakyReLU(0.1)
+
+        if use_eql:
+            self.conv_1 = _equalized_conv2d(in_channels, out_channels, (3, 3), bias=True, pad=1)
+            self.conv_2 = _equalized_conv2d(out_channels, out_channels, (3, 3), bias=True, pad=1)
+        else:
+            self.conv_1 = Conv2d(in_channels, out_channels, (3, 3), padding=1, bias=True)
+            self.conv_2 = Conv2d(out_channels, out_channels, (3, 3), padding=1, bias=True)
+
+        # if use_eql:
+        #     self.conv_1 = _equalized_deconv2d(in_channels, in_channels, (4, 4), bias=True)
+        #     self.conv_2 = _equalized_conv2d(in_channels, in_channels, (3, 3),
+        #                                     pad=1, bias=True)
+        #
+        # else:
+        #     from torch.nn import Conv2d, ConvTranspose2d
+        #     self.conv_1 = ConvTranspose2d(in_channels, in_channels, (4, 4), bias=True)
+        #     self.conv_2 = Conv2d(in_channels, in_channels, (3, 3), padding=1, bias=True)
+
+        # Pixelwise feature vector normalization operation
+        self.pixNorm = PixelwiseNorm()
+
+        # leaky_relu:
+        self.lrelu = LeakyReLU(0.2)
+
+    def forward(self, x):
+        """
+        forward pass of the block
+        :param x: input to the module
+        :return: y => output
+        """
+        # convert the tensor shape:
+        # y = th.unsqueeze(th.unsqueeze(x, -1), -1)
+        # perform the forward computations:
+        y = self.lrelu(self.conv_1(x))
+        y = self.lrelu(self.conv_2(y))
+
+        # apply pixel norm
+        y = self.pixNorm(y)
+
+        return y
+
+
+class GenGeneralConvBlock(th.nn.Module):
+    """ Module implementing a general convolutional block """
+
+    def __init__(self, in_channels, out_channels, use_eql):
+        """
+        constructor for the class
+        :param in_channels: number of input channels to the block
+        :param out_channels: number of output channels required
+        :param use_eql: whether to use equalized learning rate
+        """
+        from torch.nn import LeakyReLU
+        from torch.nn.functional import interpolate
+
+        super(GenGeneralConvBlock, self).__init__()
+
+        self.upsample = lambda x: interpolate(x, scale_factor=2)
+
+        if use_eql:
+            self.conv_1 = _equalized_conv2d(in_channels, out_channels, (3, 3),
+                                            pad=1, bias=True)
+            self.conv_2 = _equalized_conv2d(out_channels, out_channels, (3, 3),
+                                            pad=1, bias=True)
+        else:
+            from torch.nn import Conv2d
+            self.conv_1 = Conv2d(in_channels, out_channels, (3, 3),
+                                 padding=1, bias=True)
+            self.conv_2 = Conv2d(out_channels, out_channels, (3, 3),
+                                 padding=1, bias=True)
+
+        # Pixelwise feature vector normalization operation
+        self.pixNorm = PixelwiseNorm()
+
+        # leaky_relu:
+        self.lrelu = LeakyReLU(0.2)
+
+    def forward(self, x):
+        """
+        forward pass of the block
+        :param x: input
+        :return: y => output
+        """
+        y = self.upsample(x)
+        y = self.pixNorm(self.lrelu(self.conv_1(y)))
+        y = self.pixNorm(self.lrelu(self.conv_2(y)))
+
+        return y
+
+
+# function to calculate the Exponential moving averages for the Generator weights
+# This function updates the exponential average weights based on the current training
+def update_average(model_tgt, model_src, beta):
+    """
+    update the model_target using exponential moving averages
+    :param model_tgt: target model
+    :param model_src: source model
+    :param beta: value of decay beta
+    :return: None (updates the target model)
+    """
+
+    # utility function for toggling the gradient requirements of the models
+    def toggle_grad(model, requires_grad):
+        for p in model.parameters():
+            p.requires_grad_(requires_grad)
+
+    # turn off gradient calculation
+    toggle_grad(model_tgt, False)
+    toggle_grad(model_src, False)
+
+    param_dict_src = dict(model_src.named_parameters())
+
+    for p_name, p_tgt in model_tgt.named_parameters():
+        p_src = param_dict_src[p_name]
+        assert (p_src is not p_tgt)
+        p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)
+
+    # turn back on the gradient calculation
+    toggle_grad(model_tgt, True)
+    toggle_grad(model_src, True)
+
+
+class MinibatchStdDev(th.nn.Module):
+    """
+    Minibatch standard deviation layer for the discriminator
+    """
+
+    def __init__(self):
+        """
+        derived class constructor
+        """
+        super(MinibatchStdDev, self).__init__()
+
+    def forward(self, x, alpha=1e-6):
+        """
+        forward pass of the layer
+        :param x: input activation volume
+        :param alpha: small number for numerical stability
+        :return: y => x appended with standard deviation constant map
+        """
+        batch_size, _, height, width = x.shape
+
+        # [B x C x H x W]  Subtract mean over batch.
+        y = x - x.mean(dim=0, keepdim=True)
+
+        # [1 x C x H x W]  Calc standard deviation over batch
+        y = th.sqrt(y.pow(2.).mean(dim=0, keepdim=False) + alpha)
+
+        # [1]  Take average over feature_maps and pixels.
+        y = y.mean().view(1, 1, 1, 1)
+
+        # [B x 1 x H x W]  Replicate over group and pixels.
+        y = y.repeat(batch_size, 1, height, width)
+
+        # [B x C+1 x H x W]  Append as new feature_map.
+        y = th.cat([x, y], 1)
+
+        # return the computed values:
+        return y
+
+
+class DisFinalBlock(th.nn.Module):
+    """ Final block for the Discriminator """
+
+    def __init__(self, in_channels, use_eql):
+        """
+        constructor of the class
+        :param in_channels: number of input channels
+        :param use_eql: whether to use equalized learning rate
+        """
+        from torch.nn import LeakyReLU
+
+        super(DisFinalBlock, self).__init__()
+
+        # declare the required modules for forward pass
+        self.batch_discriminator = MinibatchStdDev()
+        if use_eql:
+            self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1, bias=True)
+            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4), bias=True)
+            # final conv layer emulates a fully connected layer
+            self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
+        else:
+            from torch.nn import Conv2d
+            self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
+            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True)
+            # final conv layer emulates a fully connected layer
+            self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)
+
+        # leaky_relu:
+        self.lrelu = LeakyReLU(0.2)
+
+    def forward(self, x):
+        """
+        forward pass of the FinalBlock
+        :param x: input
+        :return: y => output
+        """
+        # minibatch_std_dev layer
+        y = self.batch_discriminator(x)
+
+        # define the computations
+        y = self.lrelu(self.conv_1(y))
+        y = self.lrelu(self.conv_2(y))
+
+        # fully connected layer
+        y = self.conv_3(y)  # this layer has linear activation
+
+        # flatten the output raw discriminator scores
+        return y.view(-1)
+
+
+class ConDisFinalBlock(th.nn.Module):
+    """ Final block for the Conditional Discriminator
+        Uses the Projection mechanism from the paper -> https://arxiv.org/pdf/1802.05637.pdf
+    """
+
+    def __init__(self, in_channels, num_classes, use_eql):
+        """
+        constructor of the class
+        :param in_channels: number of input channels
+        :param num_classes: number of classes for conditional discrimination
+        :param use_eql: whether to use equalized learning rate
+        """
+        from torch.nn import LeakyReLU, Embedding
+
+        super(ConDisFinalBlock, self).__init__()
+
+        # declare the required modules for forward pass
+        self.batch_discriminator = MinibatchStdDev()
+
+        if use_eql:
+            self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1, bias=True)
+            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4), bias=True)
+
+            # final conv layer emulates a fully connected layer
+            self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
+        else:
+            from torch.nn import Conv2d
+            self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
+            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True)
+
+            # final conv layer emulates a fully connected layer
+            self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)
+
+        # we also need an embedding matrix for the label vectors
+        self.label_embedder = Embedding(num_classes, in_channels, max_norm=1)
+
+        # leaky_relu:
+        self.lrelu = LeakyReLU(0.2)
+
+    def forward(self, x, labels):
+        """
+        forward pass of the FinalBlock
+        :param x: input
+        :param labels: samples' labels for conditional discrimination
+                       Note that these are pure integer labels [Batch_size x 1]
+        :return: y => output
+        """
+        # minibatch_std_dev layer
+        y = self.batch_discriminator(x)  # [B x C x 4 x 4]
+
+        # perform the forward pass
+        y = self.lrelu(self.conv_1(y))  # [B x C x 4 x 4]
+
+        # obtain the computed features
+        y = self.lrelu(self.conv_2(y))  # [B x C x 1 x 1]
+
+        # embed the labels
+        labels = self.label_embedder(labels)  # [B x C]
+
+        # compute the inner product with the label embeddings
+        y_ = th.squeeze(th.squeeze(y, dim=-1), dim=-1)  # [B x C]
+        projection_scores = (y_ * labels).sum(dim=-1)  # [B]
+
+        # normal discrimination score
+        y = self.lrelu(self.conv_3(y))  # note: lrelu here, unlike the linear conv_3 in DisFinalBlock
+
+        # calculate the total score
+        final_score = y.view(-1) + projection_scores
+
+        # return the output raw discriminator scores
+        return final_score
+
+
+class DisGeneralConvBlock(th.nn.Module):
+    """ General block in the discriminator """
+
+    def __init__(self, in_channels, out_channels, use_eql):
+        """
+        constructor of the class
+        :param in_channels: number of input channels
+        :param out_channels: number of output channels
+        :param use_eql: whether to use equalized learning rate
+        """
+        from torch.nn import AvgPool2d, LeakyReLU
+
+        super(DisGeneralConvBlock, self).__init__()
+
+        if use_eql:
+            self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3), pad=1, bias=True)
+            self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3), pad=1, bias=True)
+        else:
+            from torch.nn import Conv2d
+            self.conv_1 = Conv2d(in_channels, in_channels, (3, 3), padding=1, bias=True)
+            self.conv_2 = Conv2d(in_channels, out_channels, (3, 3), padding=1, bias=True)
+
+        self.downSampler = AvgPool2d(2)
+
+        # leaky_relu:
+        self.lrelu = LeakyReLU(0.2)
+
+    def forward(self, x):
+        """
+        forward pass of the module
+        :param x: input
+        :return: y => output
+        """
+        # define the computations
+        y = self.lrelu(self.conv_1(x))
+        y = self.lrelu(self.conv_2(y))
+        y = self.downSampler(y)
+
+        return y
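The equalized-learning-rate layers above keep their weights ~N(0, 1) and apply the He constant sqrt(2 / fan_in) at call time rather than at initialization, so all layers see a comparable effective learning rate under Adam. A quick sanity check of the scale (a sketch, not part of the patch):

    import torch
    from models.progan_layers import _equalized_conv2d

    conv = _equalized_conv2d(16, 32, (3, 3), pad=1)
    print(conv.scale)                   # sqrt(2 / (3 * 3 * 16)) ~= 0.118
    y = conv(torch.randn(1, 16, 8, 8))  # -> torch.Size([1, 32, 8, 8])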
diff --git a/models/progan_model.py b/models/progan_model.py
new file mode 100644
index 00000000000..11a84f5a2ed
--- /dev/null
+++ b/models/progan_model.py
@@ -0,0 +1,290 @@
+import torch
+from torch.autograd import grad
+
+from .base_model import BaseModel
+from . import networks
+import numpy as np
+
+from .progan_layers import update_average
+
+try:
+    from apex import amp
+except ImportError:
+    print("Please install NVIDIA Apex for safe mixed precision if you want to use a non-default --opt_level")
+
+
+class ProGanModel(BaseModel):
+    """
+    This is an implementation of the paper "Progressive Growing of GANs": https://arxiv.org/abs/1710.10196.
+    The model requires a dataset of type dataset_mode='single', generator netG='progan' and discriminator netD='progan'.
+    Please note that opt.crop_size (default 256) == 4 * 2 ** opt.max_steps (default max_steps is 6).
+    ngf and ndf control the width of the backbone (128-512).
+    Network G is the master generator (it accumulates weights for eval) and network C (stands for "current")
+    is the generator currently being trained.
+
+    See also:
+        https://github.com/tkarras/progressive_growing_of_gans
+        https://github.com/odegeasslbc/Progressive-GAN-pytorch
+    """
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train=True):
+
+        parser.set_defaults(netG='progan', netD='progan', dataset_mode='single', beta1=0., ngf=512, ndf=512,
+                            gan_mode='wgangp')
+        parser.add_argument('--z_dim', type=int, default=128, help='random noise dim')
+        parser.add_argument('--max_steps', type=int, default=6, help='steps of growing')
+        parser.add_argument('--steps_schedule', type=str, default='linear',
+                            help='when to move on to the next step: linear or fibonacci')
+        return parser
+
+    def __init__(self, opt):
+
+        BaseModel.__init__(self, opt)
+        self.z_dim = opt.z_dim
+        self.max_steps = opt.max_steps
+        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+        self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'D']
+        if self.opt.gan_mode == 'wgangp':
+            self.loss_names.append('D_gradpen')
+        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
+        self.visual_names = ['fake_B', 'real_B']
+        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
+        if self.isTrain:
+            self.model_names = ['G', 'D', 'C']
+        else:  # during test time, only load G
+            self.model_names = ['G']
+
+        if self.isTrain:
+            assert opt.crop_size == 4 * 2 ** self.max_steps
+            assert opt.beta1 == 0
+
+        # define networks (both generator and discriminator)
+        self.netG = networks.define_G(opt.z_dim, opt.input_nc, opt.ngf, opt.netG, opt.norm,
+                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
+                                      init_weights=False, max_steps=self.max_steps)
+        """
+        resulting generator, not for training, just for eval
+        """
+
+        if self.isTrain:
+            self.netC = networks.define_G(opt.z_dim, opt.input_nc, opt.ngf, opt.netG, opt.norm,
+                                          not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
+                                          init_weights=False, max_steps=self.max_steps)
+            """
+            the generator currently being trained
+            """
+            self.netD = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
+                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids,
+                                          init_weights=False, max_steps=self.max_steps)
+            """
+            the discriminator currently being trained
+            """
+
+        if self.isTrain:
+            # define loss functions
+            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+            self.criterionL1 = torch.nn.L1Loss()
+            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
+            self.optimizer_C = torch.optim.Adam(self.netC.parameters(), lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-6)
+            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-6)
+            self.optimizers.append(self.optimizer_C)
+            self.optimizers.append(self.optimizer_D)
+
+            if opt.apex:
+                [self.netC, self.netD], [self.optimizer_C, self.optimizer_D] = amp.initialize(
+                    [self.netC, self.netD], [self.optimizer_C, self.optimizer_D], opt_level=opt.opt_level, num_losses=2)
+
+        self.make_data_parallel()
+
+        # inner counters
+        self.total_steps = opt.n_epochs + opt.n_epochs_decay + 1
+        """
+        total number of epochs
+        """
+        self.step = 0
+        """
+        current step of the network, 1-6
+        """
+        self.iter = 0
+        """
+        current iter, 0-(total epochs)//6
+        """
+        self.alpha = 0.
+        """
+        current alpha rate used to fuse the different scales
+        """
+        self.epochs_schedule = self.create_epochs_schedule(opt.steps_schedule)
+        """
+        schedule of when to move on to the next step
+        """
+
+        if self.isTrain:
+            assert self.total_steps > 12
+            assert self.opt.crop_size % 2 ** self.max_steps == 0
+
+            # set fusing
+            self.netG.eval()
+            self.accumulate(0)
+ """ + AtoB = self.opt.direction == 'AtoB' + self.real_B = input['A' if AtoB else 'B'].to(self.device) + self.real_B = self.__progressive_downsampling(self.real_B, self.step, self.alpha) + #self.real_B = F.interpolate(self.real_B, size=(4 * 2 ** self.step, 4 * 2 ** self.step), mode='bilinear') + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + """Run forward pass; called by both functions and .""" + net = self.netC if self.isTrain else self.netG + step = self.step if self.isTrain else self.max_steps + alpha = self.alpha if self.isTrain else 1 + batch_size = self.real_B.size(0) + z = torch.randn((batch_size, self.z_dim, self.opt.crop_size // (2 ** self.max_steps), + self.opt.crop_size // (2 ** self.max_steps)), + device=self.device) + # z = torch.randn(batch_size, 512).to(self.device) + self.fake_B = net(z, step=step, alpha=alpha) + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # Fake; stop backprop to the generator by detaching fake_B + fake_B = self.fake_B + pred_fake = self.netD(fake_B.detach(), step=self.step, alpha=self.alpha) + self.loss_D_fake = self.criterionGAN(pred_fake, False) + # Real + real_B = self.real_B + pred_real = self.netD(real_B, step=self.step, alpha=self.alpha) + self.loss_D_real = self.criterionGAN(pred_real, True) + if self.opt.gan_mode == 'wgangp': + # some correction of D loss + self.loss_D_real += 0.001 * (pred_real ** 2).mean() + # combine loss and calculate gradients + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + if self.opt.gan_mode == 'wgangp': + ### gradient penalty for D + b_size = fake_B.size(0) + eps = torch.rand(b_size, 1, 1, 1, dtype=fake_B.dtype, device=fake_B.device).to(fake_B.device) + x_hat = eps * real_B.data + (1 - eps) * fake_B.detach().data + x_hat.requires_grad = True + hat_predict = self.netD(x_hat, step=self.step, alpha=self.alpha) + grad_x_hat = grad( + outputs=hat_predict.sum(), inputs=x_hat, create_graph=True)[0] + grad_penalty = ((grad_x_hat.view(grad_x_hat.size(0), -1) + .norm(2, dim=1) - 1) ** 2).mean() + self.loss_D_gradpen = 10 * grad_penalty + self.loss_D += self.loss_D_gradpen + + if not (torch.isinf(self.loss_D) or torch.isnan(self.loss_D) or torch.mean(torch.abs(self.loss_D)) > 100): + if self.opt.apex: + with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled: + loss_D_scaled.backward() + else: + self.loss_D.backward() + + def backward_G(self): + """Calculate GAN loss for the generator""" + # First, G(A) should fake the discriminator + fake_B = self.fake_B + pred_fake = self.netD(fake_B, step=self.step, alpha=self.alpha) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + self.loss_G = self.loss_G_GAN + + if not (torch.isinf(self.loss_G) or torch.isnan(self.loss_G) or torch.mean(torch.abs(self.loss_G)) > 100): + if self.opt.apex: + with amp.scale_loss(self.loss_G, self.optimizer_C, loss_id=1) as loss_G_scaled: + loss_G_scaled.backward() + else: + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() # compute fake images: G(A) + # update D + self.set_requires_grad(self.netD, True) # enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + self.backward_D() # calculate gradients for D + self.optimizer_D.step() # update D's weights + # update generator C + self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G + self.optimizer_C.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G + self.optimizer_C.step() # udpate 
+        self.accumulate()                          # fuse params
+
+    def create_epochs_schedule(self, steps_schedule_type):
+        if steps_schedule_type == 'fibonacci':
+            basic_weights = np.array([3., 3., 3., 5., 8., 13., 21., 34., 55.])
+        else:
+            basic_weights = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1])
+        basic_weights = basic_weights[:self.max_steps + 1]
+        epochs_schedule = self.total_steps * basic_weights / np.sum(basic_weights)
+        epochs_schedule = epochs_schedule.astype(int)  # plain int; np.int is deprecated
+        print('schedule of step turning: %s' % str(epochs_schedule))
+        return epochs_schedule
+
+    def update_inners_counters(self):
+        """
+        Update the iteration counters
+        """
+        self.iter += 1
+        self.alpha = min(1, (2. / (self.epochs_schedule[self.step])) * self.iter)
+        if self.iter > self.epochs_schedule[self.step]:
+            print('turn to step %s' % str(self.step + 1))
+            self.alpha = 0
+            self.iter = 0
+            self.step += 1
+
+            if self.step > self.max_steps:
+                self.alpha = 1
+                self.step = self.max_steps
+
+        print('new alpha: %s, new step: %s' % (self.alpha, self.step))
+
+    def accumulate(self, decay=0.999):
+        """
+        Accumulate the weights from self.netC into self.netG
+        :param decay: exponential moving average decay
+        """
+        update_average(self.netG, self.netC, decay)
+
+    def update_learning_rate(self):
+        super(ProGanModel, self).update_learning_rate()
+        self.update_inners_counters()
+
+    def __progressive_downsampling(self, real_batch, depth, alpha):
+        """
+        private helper for downsampling the original images in order to facilitate the
+        progressive growing of the layers.
+        :param real_batch: batch of real samples
+        :param depth: depth at which training is going on
+        :param alpha: current value of the fader alpha
+        :return: real_samples => modified real batch of samples
+        """
+        from torch.nn import AvgPool2d
+        from torch.nn.functional import interpolate
+
+        # downsample the real_batch for the given depth
+        down_sample_factor = int(np.power(2, self.max_steps - depth))
+        prior_downsample_factor = max(int(np.power(2, self.max_steps - depth + 1)), 0)
+
+        ds_real_samples = AvgPool2d(down_sample_factor)(real_batch)
+
+        if depth > 0:
+            prior_ds_real_samples = interpolate(AvgPool2d(prior_downsample_factor)(real_batch),
+                                                scale_factor=2)
+        else:
+            prior_ds_real_samples = ds_real_samples
+
+        # real samples are a combination of ds_real_samples and prior_ds_real_samples
+        real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)
+
+        # return the so computed real_samples
+        return real_samples
\ No newline at end of file
diff --git a/models/template_model.py b/models/template_model.py
index 68cdaf6a9a2..45d3659ca4c 100644
--- a/models/template_model.py
+++ b/models/template_model.py
@@ -67,6 +67,8 @@ def __init__(self, opt):
             self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizers = [self.optimizer]
 
+        # the networks need to be wrapped after amp.initialize
+        self.make_data_parallel()
         # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
 
     def set_input(self, input):
diff --git a/models/test_model.py b/models/test_model.py
index fe15f40176e..a9ab50064db 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -48,6 +48,7 @@ def __init__(self, opt):
         # assigns the model to self.netG_[suffix] so that it can be loaded
         # please see <BaseModel.load_networks>
         setattr(self, 'netG' + opt.model_suffix, self.netG)  # store netG in self.
+        self.make_data_parallel()
 
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
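A note on the flag plumbing added below: the models never read --opt_level directly. TrainOptions.parse() derives a single boolean from it, so the default O0 (pure fp32) cleanly disables every amp code path:

    opt = TrainOptions().parse()  # sets opt.apex = (opt.opt_level != "O0")
    if opt.apex:
        ...                       # the amp.initialize / amp.scale_loss branches in the models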
diff --git a/options/train_options.py b/options/train_options.py
index c8d5d2a92a9..575cd3cbbf5 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -36,5 +36,14 @@ def initialize(self, parser):
         parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
         parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
 
+        # training optimizations
+        parser.add_argument('--checkpointing', action='store_true',
+                            help='if true, apply gradient checkpointing; saves memory at the cost of slower training')
+        parser.add_argument('--opt_level', default='O0', help='amp opt_level, default="O0" equals fp32 training')
+
         self.isTrain = True
         return parser
+
+    def parse(self):
+        opt = BaseOptions.parse(self)
+        opt.apex = opt.opt_level != "O0"
+        return opt
diff --git a/scripts/train_progan.sh b/scripts/train_progan.sh
new file mode 100755
index 00000000000..aedbf5c4bfd
--- /dev/null
+++ b/scripts/train_progan.sh
@@ -0,0 +1,2 @@
+set -ex
+python train.py --dataroot ./datasets/facades --name progan_train --model progan --pool_size 50 --dataset_mode single --batch_size 1 --crop_size 256 --load_size 300 --preprocess resize_and_crop
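For reference, example invocations that exercise the new options (dataset paths and experiment names are only placeholders):

    # CycleGAN with gradient checkpointing plus O1 mixed precision (requires Apex)
    python train.py --dataroot ./datasets/maps --name maps_cyclegan_amp --model cycle_gan --checkpointing --opt_level O1
    # pix2pix with O1 mixed precision; the default --opt_level O0 keeps pure fp32
    python train.py --dataroot ./datasets/facades --name facades_pix2pix_amp --model pix2pix --direction BtoA --opt_level O1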