Skip to content

Commit 517c52b

Browse files
author
Jaan Altosaar
authored
Update train_variational_autoencoder_pytorch.py
1 parent 8cf0057 commit 517c52b

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

train_variational_autoencoder_pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
test_batch_size: 512
2828
max_iterations: 100000
2929
log_interval: 10000
30+
early_stopping_interval: 5
3031
n_samples: 128
3132
use_gpu: true
3233
train_dir: $TMPDIR
@@ -183,6 +184,7 @@ def evaluate(n_samples, model, variational, eval_data):
183184
if __name__ == '__main__':
184185
dictionary = yaml.load(config)
185186
cfg = nomen.Config(dictionary)
187+
cfg.parse_args()
186188
device = torch.device("cuda:0" if cfg.use_gpu else "cpu")
187189
torch.manual_seed(cfg.seed)
188190
np.random.seed(cfg.seed)
@@ -241,7 +243,7 @@ def evaluate(n_samples, model, variational, eval_data):
241243
else:
242244
num_no_improvement += 1
243245

244-
if num_no_improvement > 5:
246+
if num_no_improvement > cfg.early_stopping_interval:
245247
checkpoint = torch.load(cfg.train_dir / 'best_state_dict')
246248
model.load_state_dict(checkpoint['model'])
247249
variational.load_state_dict(checkpoint['variational'])

0 commit comments

Comments
 (0)