diff --git a/README.md b/README.md
index 560baf3..9899093 100644
--- a/README.md
+++ b/README.md
@@ -490,3 +490,14 @@ Thank you to Matthew Mann for his inspiring [simple port](https://github.com/man
     primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@misc{chavdarova2021taming,
+    title = {Taming GANs with Lookahead-Minmax},
+    author = {Tatjana Chavdarova and Matteo Pagliardini and Sebastian U. Stich and Francois Fleuret and Martin Jaggi},
+    year = {2021},
+    eprint = {2006.14567},
+    archivePrefix = {arXiv},
+    primaryClass = {stat.ML}
+}
+```
diff --git a/stylegan2_pytorch/cli.py b/stylegan2_pytorch/cli.py
index 03f9863..7a09bbe 100644
--- a/stylegan2_pytorch/cli.py
+++ b/stylegan2_pytorch/cli.py
@@ -116,7 +116,10 @@ def train_from_folder(
     calculate_fid_num_images = 12800,
     clear_fid_cache = False,
     seed = 42,
-    log = False
+    log = False,
+    lookahead = False,
+    lookahead_alpha = 0.5,
+    lookahead_k = 5
 ):
     model_args = dict(
         name = name,
@@ -155,7 +158,10 @@ def train_from_folder(
         calculate_fid_num_images = calculate_fid_num_images,
         clear_fid_cache = clear_fid_cache,
         mixed_prob = mixed_prob,
-        log = log
+        log = log,
+        lookahead = lookahead,
+        lookahead_alpha = lookahead_alpha,
+        lookahead_k = lookahead_k
     )
 
     if generate:
diff --git a/stylegan2_pytorch/stylegan2_pytorch.py b/stylegan2_pytorch/stylegan2_pytorch.py
index 2f4dc3e..eff7a67 100644
--- a/stylegan2_pytorch/stylegan2_pytorch.py
+++ b/stylegan2_pytorch/stylegan2_pytorch.py
@@ -4,6 +4,7 @@
 import fire
 import json
 
+from collections import defaultdict
 from tqdm import tqdm
 from math import floor, log2
 from random import random
@@ -284,6 +285,30 @@ def slerp(val, low, high):
     res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
     return res
 
+# lookahead
+class Lookahead(torch.optim.Optimizer):
+    def __init__(self, optimizer, alpha = 0.5):
+        self.optimizer = optimizer
+        self.alpha = alpha
+        self.param_groups = self.optimizer.param_groups
+        self.state = defaultdict(dict)
+
+    def lookahead_step(self):
+        for group in self.param_groups:
+            for fast in group["params"]:
+                param_state = self.state[fast]
+                if "slow_params" not in param_state:
+                    param_state["slow_params"] = torch.zeros_like(fast.data)
+                    param_state["slow_params"].copy_(fast.data)
+                slow = param_state["slow_params"]
+                # slow <- slow + alpha * (fast - slow)
+                slow += (fast.data - slow) * self.alpha
+                fast.data.copy_(slow)
+
+    def step(self, closure = None):
+        loss = self.optimizer.step(closure)
+        return loss
+
 # losses
 
 def gen_hinge_loss(fake, real):
@@ -677,7 +702,7 @@ def forward(self, x):
         return x.squeeze(), quantize_loss
 
 class StyleGAN2(nn.Module):
-    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, ttur_mult = 2, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, lr_mlp = 0.1, rank = 0):
+    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, ttur_mult = 2, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, lr_mlp = 0.1, rank = 0, lookahead = False, lookahead_alpha = 0.5):
         super().__init__()
         self.lr = lr
         self.steps = steps
@@ -710,6 +735,11 @@ def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8
         self.G_opt = Adam(generator_params, lr = self.lr, betas=(0.5, 0.9))
         self.D_opt = Adam(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9))
 
+        if lookahead:
+            # wrap the optimizers with the lookahead optimizer
+            self.G_opt = Lookahead(self.G_opt, alpha = lookahead_alpha)
+            self.D_opt = Lookahead(self.D_opt, alpha = lookahead_alpha)
+
         # init weights
         self._init_weights()
         self.reset_parameter_averaging()
@@ -792,6 +822,9 @@ def __init__(
         rank = 0,
         world_size = 1,
         log = False,
+        lookahead = False,
+        lookahead_alpha = 0.5,
+        lookahead_k = 5,
         *args,
         **kwargs
     ):
@@ -880,6 +913,10 @@ def __init__(
         self.logger = aim.Session(experiment=name) if log else None
 
+        self.lookahead = lookahead
+        self.lookahead_k = lookahead_k
+        self.lookahead_alpha = lookahead_alpha
+
     @property
     def image_extension(self):
         return 'jpg' if not self.transparent else 'png'
 
@@ -894,7 +931,7 @@ def hparams(self):
 
     def init_GAN(self):
         args, kwargs = self.GAN_params
-        self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)
+        self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, lookahead = self.lookahead, lookahead_alpha = self.lookahead_alpha, *args, **kwargs)
 
         if self.is_ddp:
             ddp_kwargs = {'device_ids': [self.rank]}
@@ -1112,15 +1149,22 @@ def train(self):
             self.GAN.G_opt.step()
 
         # calculate moving averages
+        if self.lookahead and (self.steps + 1) % self.lookahead_k == 0:
+            # joint lookahead update of G and D
+            self.GAN.D_opt.lookahead_step()
+            self.GAN.G_opt.lookahead_step()
+
+            if self.is_main:
+                self.GAN.EMA()
 
         if apply_path_penalty and not np.isnan(avg_pl_length):
             self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
             self.track(self.pl_mean, 'PL')
 
-        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
+        if self.is_main and not self.lookahead and self.steps % 10 == 0 and self.steps > 20000:
             self.GAN.EMA()
 
-        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
+        if self.is_main and not self.lookahead and self.steps <= 25000 and self.steps % 1000 == 2:
             self.GAN.reset_parameter_averaging()
 
         # save from NaN errors
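
The `Lookahead` wrapper added above keeps a per-parameter copy of "slow" weights and delegates ordinary `step` calls to the wrapped optimizer; every `lookahead_k` steps the trainer calls `lookahead_step`, which pulls the slow weights toward the fast ones (`slow <- slow + alpha * (fast - slow)`) and then resets the fast weights to the result. Following Chavdarova et al.'s Lookahead-Minmax, the sync is applied jointly to G and D, and the `not self.lookahead` guards swap the default EMA schedule for EMA at sync points only. Below is a minimal standalone sketch of the wrapper on a toy objective, assuming the import path from this diff; since the wrapper never calls `Optimizer.__init__`, the sketch zeroes gradients on the inner Adam rather than relying on the inherited `zero_grad`, which depends on attributes the wrapper does not set on some PyTorch versions:

```python
import torch
from torch.optim import Adam

# Lookahead lives in stylegan2_pytorch/stylegan2_pytorch.py after this change;
# the import path assumes the package layout shown in the diff.
from stylegan2_pytorch.stylegan2_pytorch import Lookahead

w = torch.randn(8, requires_grad = True)          # toy "fast" parameter
opt = Lookahead(Adam([w], lr = 1e-3), alpha = 0.5)

lookahead_k = 5  # sync period, mirroring the trainer's lookahead_k
for step in range(20):
    loss = (w ** 2).sum()                         # toy quadratic objective
    opt.optimizer.zero_grad()                     # zero grads on the inner Adam
    loss.backward()
    opt.step()                                    # fast update, delegated to Adam
    if (step + 1) % lookahead_k == 0:
        # slow <- slow + alpha * (fast - slow); fast <- slow
        opt.lookahead_step()
```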
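
End to end, the new CLI flags simply flow through `train_from_folder` into the `Trainer` and on to `StyleGAN2`. A hypothetical invocation sketch, calling the fire entrypoint from `stylegan2_pytorch/cli.py` directly (equivalent to passing `--lookahead` on the command line; the `data` argument and the `./data` path are assumptions for illustration):

```python
from stylegan2_pytorch.cli import train_from_folder

train_from_folder(
    data = './data',            # path to your image folder (assumed)
    lookahead = True,           # wrap both G and D optimizers with Lookahead
    lookahead_alpha = 0.5,      # slow-weight interpolation factor
    lookahead_k = 5             # joint G/D sync every k optimizer steps
)
```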