Skip to content

Commit 9053a98

Browse files
committed
MNIST example: clarify train/stats runs in code.
1 parent e85e5ea commit 9053a98

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

examples/MNIST/run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def main(params):
1818

1919
for epoch in range(100):
2020
model.training()
21-
train(train_set_x, train_set_y, model, optimiser, criterion, epoch, params['batch_size'])
22-
train(train_set_x, train_set_y, model, optimiser, criterion, epoch, params['batch_size'], 'stat')
21+
train(train_set_x, train_set_y, model, optimiser, criterion, epoch, params['batch_size'], 'train')
22+
train(train_set_x, train_set_y, model, optimiser, criterion, epoch, params['batch_size'], 'stats')
2323

2424
model.evaluate()
2525
validate(test_set_x, test_set_y, model, epoch, params['batch_size'])
@@ -29,4 +29,4 @@ def main(params):
2929
params = {}
3030
params['lr'] = 0.1
3131
params['batch_size'] = 64
32-
main(params)
32+
main(params)

examples/MNIST/train.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import theano as _th
44

55

6-
def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size, mode=None):
7-
progress = make_progressbar('Training', epoch, len(dataset_x))
6+
def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size, mode='train'):
7+
progress = make_progressbar('Training ({})'.format(mode), epoch, len(dataset_x))
88
progress.start()
99

1010
shuffle = np.random.permutation(len(dataset_x))
@@ -17,12 +17,14 @@ def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size,
1717
mini_batch_input[k] = dataset_x[shuffle[j * batch_size + k]]
1818
mini_batch_targets[k] = dataset_y[shuffle[j * batch_size + k]]
1919

20-
if mode is None:
20+
if mode == 'train':
2121
model.zero_grad_parameters()
2222
model.accumulate_gradients(mini_batch_input, mini_batch_targets, criterion)
2323
optimiser.update_parameters(model)
24-
else:
24+
elif mode == 'stats':
2525
model.accumulate_statistics(mini_batch_input)
26+
else:
27+
assert False, "Mode should be either 'train' or 'stats'"
2628

2729
progress.update(j * batch_size)
2830

0 commit comments

Comments
 (0)