Skip to content

Commit 845b214

Browse files
committed
Set total_epochs using standard epochs keyword
1 parent ee33263 commit 845b214

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

examples/image-vae/image_vae_baseline_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def run(gParams):
9797
test_file = candle.fetch_file(data_url + test_data, subdir='Examples/image_vae')
9898

9999
starting_epoch = 1
100-
total_epochs = None
100+
total_epochs = gParams['epochs']
101101

102102
rng_seed = 42
103103
torch.manual_seed(rng_seed)
@@ -263,7 +263,7 @@ def test(epoch, args):
263263
if total_epochs is None:
264264
trn_rng = itertools.count(start=starting_epoch)
265265
else:
266-
trn_rng = range(starting_epoch, total_epochs)
266+
trn_rng = range(starting_epoch, total_epochs + 1)
267267

268268
for epoch in trn_rng:
269269
for param_group in optimizer.param_groups:

examples/image-vae/image_vae_default_model.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ data_url = 'ftp://ftp.mcs.anl.gov/pub/candle/public/benchmarks/Examples/image_va
33
train_data = 'train.csv'
44
test_data = 'test.csv'
55
workers = 16
6+
epochs = None
67
batch_size = 256
78
grad_clip = 2.0
89
model_path = 'models'

0 commit comments

Comments
 (0)