Skip to content

Commit b96a500

Browse files
committed
Kaggle-Otto train/stats modes a bit more obvious.
Just like in MNIST.
1 parent 9fab276 commit b96a500

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

examples/Kaggle-Otto/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def nnet():
5959
model.training()
6060
if epoch % 100 == 0:
6161
optimiser.hyperparams['lr'] /= 10
62-
train(train_data_x, train_data_y, model, optimiser, criterion, epoch, 100)
63-
train(train_data_x, train_data_y, model, optimiser, criterion, epoch, 100, 'stat')
62+
train(train_data_x, train_data_y, model, optimiser, criterion, epoch, 100, 'train')
63+
train(train_data_x, train_data_y, model, optimiser, criterion, epoch, 100, 'stats')
6464

6565
model.evaluate()
6666
validate(test_data_x, test_data_y, model, epoch, 100)

examples/Kaggle-Otto/train.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
from examples.utils import make_progressbar
55

6-
def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size, mode=None):
7-
progress = make_progressbar('Training epoch #{}'.format(epoch), len(dataset_x))
6+
7+
def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size, mode='train'):
8+
progress = make_progressbar('Training ({}) epoch #{}'.format(mode, epoch), len(dataset_x))
89
progress.start()
910

1011
shuffle = np.random.permutation(len(dataset_x))
@@ -17,12 +18,14 @@ def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size,
1718
mini_batch_input[k] = dataset_x[shuffle[j * batch_size + k]]
1819
mini_batch_targets[k] = dataset_y[shuffle[j * batch_size + k]]
1920

20-
if mode is None:
21+
if mode == 'train':
2122
model.zero_grad_parameters()
2223
model.accumulate_gradients(mini_batch_input, mini_batch_targets, criterion)
2324
optimiser.update_parameters(model)
24-
else:
25+
elif mode == 'stats':
2526
model.accumulate_statistics(mini_batch_input)
27+
else:
28+
assert False, "Mode should be either 'train' or 'stats'"
2629

2730
progress.update((j+1) * batch_size)
2831

0 commit comments

Comments
 (0)