diff --git a/train_vae.py b/train_vae.py
index 66bf93e..bdf983b 100644
--- a/train_vae.py
+++ b/train_vae.py
@@ -153,6 +153,9 @@ def __init__(
         dl = self.accelerator.prepare(data_loader)
         self.dl = cycle(dl)
 
+        # step counter state
+        self.step = 0
+
         # optimizer
 
         self.opt_ae = torch.optim.AdamW(list(model.encoder.parameters())+
@@ -162,7 +165,7 @@ def __init__(
                                         lr=train_lr)
         self.opt_disc = torch.optim.AdamW(model.loss.discriminator.parameters(), lr=train_lr)
         min_lr = cfg['trainer']['min_lr']
-        lr_lambda = lambda iter: max((1 - iter / train_num_steps) ** 0.95, min_lr/train_lr)
+        lr_lambda = lambda _: max((1 - self.step / train_num_steps) ** 0.95, min_lr/train_lr)
         self.lr_scheduler_ae = torch.optim.lr_scheduler.LambdaLR(self.opt_ae, lr_lambda=lr_lambda)
         self.lr_scheduler_disc = torch.optim.lr_scheduler.LambdaLR(self.opt_disc, lr_lambda=lr_lambda)
         # for logging results in a folder periodically
@@ -174,9 +177,6 @@ def __init__(
         self.results_folder = Path(results_folder)
         self.results_folder.mkdir(exist_ok = True)
 
-        # step counter state
-
-        self.step = 0
 
 
         # prepare model, dataloader, optimizer with accelerator
@@ -330,4 +330,4 @@ def train(self):
 if __name__ == "__main__":
     args = parse_args()
     main(args)
-    pass
\ No newline at end of file
+    pass
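
The functional change here: the LR decay factor now reads the trainer's own self.step counter instead of the scheduler's internal iteration argument, so the schedule stays correct when training resumes from a checkpoint that restores self.step (LambdaLR's internal counter would otherwise restart at zero unless its own state dict were also saved and reloaded). Moving self.step = 0 above the optimizer block is required, not cosmetic: in current PyTorch versions LambdaLR evaluates lr_lambda once at construction time, so the attribute must already exist when the schedulers are built. Below is a minimal runnable sketch of the pattern; the Trainer wrapper, hyperparameter values, and train_step name are illustrative stand-ins, not taken from train_vae.py:

    import torch

    class Trainer:
        def __init__(self, params, train_lr=1e-4, min_lr=1e-6, train_num_steps=1000):
            # Must be set before LambdaLR is constructed: the scheduler
            # evaluates lr_lambda once in its __init__, which reads self.step.
            self.step = 0
            self.opt = torch.optim.AdamW(params, lr=train_lr)
            # The scheduler's own counter (the ignored "_" argument) is not
            # used; the decay factor is recomputed from the trainer's
            # resumable self.step on every scheduler.step() call, and is
            # floored at min_lr/train_lr so the LR never drops below min_lr.
            lr_lambda = lambda _: max((1 - self.step / train_num_steps) ** 0.95,
                                      min_lr / train_lr)
            self.sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lr_lambda=lr_lambda)

        def train_step(self):
            # ...compute loss and call loss.backward() here in real training...
            self.opt.step()    # no-op here (no grads); shown for correct ordering
            self.step += 1
            self.sched.step()  # re-evaluates lr_lambda against the new self.step

    if __name__ == "__main__":
        trainer = Trainer([torch.nn.Parameter(torch.zeros(1))])
        for _ in range(10):
            trainer.train_step()
        print(trainer.opt.param_groups[0]["lr"])  # decayed from 1e-4

With this closure, resuming only needs self.step restored; the schedulers' state dicts do not have to be reloaded for the decay curve to pick up where it left off.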