Skip to content

Commit 5ef5d77

Browse files
committed
Merge pull request #10 from lucasb-eyer/mnist-revamp
Mnist revamp
2 parents 8eb0949 + 77172b8 commit 5ef5d77

File tree

4 files changed

+30
-28
lines changed

4 files changed

+30
-28
lines changed

examples/MNIST/mnist.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
11
import os
22
import gzip
33
import pickle
4-
import urllib
54
import sys
5+
6+
# Python 2/3 compatibility.
7+
try:
8+
from urllib.request import urlretrieve
9+
except ImportError:
10+
from urllib import urlretrieve
11+
12+
613
'''Adapted from theano tutorial'''
714

815

9-
def load_mnist(data_file = './mnist.pkl.gz'):
16+
def load_mnist(data_file = os.path.join(os.path.dirname(__file__), 'mnist.pkl.gz')):
1017

1118
if not os.path.exists(data_file):
12-
origin = ('http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz')
13-
print('Downloading data from %s' % origin)
14-
urllib.urlretrieve(origin, data_file)
19+
origin = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
20+
print('Downloading data from {}'.format(origin))
21+
urlretrieve(origin, data_file)
1522

1623
print('... loading data')
1724

18-
f = gzip.open(data_file, 'rb')
19-
if sys.version_info[0] == 3:
20-
train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
21-
else:
22-
train_set, valid_set, test_set = pickle.load(f)
23-
f.close()
24-
25-
train_set_x, train_set_y = train_set
26-
valid_set_x, valid_set_y = valid_set
27-
test_set_x, test_set_y = test_set
28-
29-
return (train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)
25+
with gzip.open(data_file, 'rb') as f:
26+
if sys.version_info[0] == 3:
27+
return pickle.load(f, encoding='latin1')
28+
else:
29+
return pickle.load(f)

examples/MNIST/run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def main(params):
1818

1919
for epoch in range(100):
2020
model.training()
21-
train(train_set_x, train_set_y, model, optimiser, criterion, epoch, params['batch_size'])
22-
train(train_set_x, train_set_y, model, optimiser, criterion, epoch, params['batch_size'], 'stat')
21+
train(train_set_x, train_set_y, model, optimiser, criterion, epoch, params['batch_size'], 'train')
22+
train(train_set_x, train_set_y, model, optimiser, criterion, epoch, params['batch_size'], 'stats')
2323

2424
model.evaluate()
2525
validate(test_set_x, test_set_y, model, epoch, params['batch_size'])
@@ -29,4 +29,4 @@ def main(params):
2929
params = {}
3030
params['lr'] = 0.1
3131
params['batch_size'] = 64
32-
main(params)
32+
main(params)

examples/MNIST/test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import numpy as np
22
from progress_bar import *
33
import theano as _th
4-
from sklearn.metrics import accuracy_score
54

65
def validate(dataset_x, dataset_y, model, epoch, batch_size):
76
progress = make_progressbar('Testing', epoch, len(dataset_x))
87
progress.start()
98

109
mini_batch_input = np.empty(shape=(batch_size, 28*28), dtype=_th.config.floatX)
1110
mini_batch_targets = np.empty(shape=(batch_size, ), dtype=_th.config.floatX)
12-
accuracy = 0
11+
nerrors = 0
1312

1413
for j in range((dataset_x.shape[0] + batch_size - 1) // batch_size):
1514
progress.update(j * batch_size)
@@ -24,7 +23,8 @@ def validate(dataset_x, dataset_y, model, epoch, batch_size):
2423
mini_batch_prediction.resize((dataset_x.shape[0] - j * batch_size, ))
2524
mini_batch_targets.resize((dataset_x.shape[0] - j * batch_size, ))
2625

27-
accuracy = accuracy + accuracy_score(mini_batch_targets, mini_batch_prediction, normalize=False)
26+
nerrors += sum(mini_batch_targets != mini_batch_prediction)
2827

2928
progress.finish()
30-
print("Epoch #" + str(epoch) + ", Classification: " + str(float(accuracy) / dataset_x.shape[0] * 100.0))
29+
accuracy = 1 - float(nerrors)/dataset_x.shape[0]
30+
print("Epoch #{}, Classification accuracy: {:.2%} ({} errors)".format(epoch, accuracy, nerrors))

examples/MNIST/train.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import theano as _th
44

55

6-
def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size, mode=None):
7-
progress = make_progressbar('Training', epoch, len(dataset_x))
6+
def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size, mode='train'):
7+
progress = make_progressbar('Training ({})'.format(mode), epoch, len(dataset_x))
88
progress.start()
99

1010
shuffle = np.random.permutation(len(dataset_x))
@@ -17,12 +17,14 @@ def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size,
1717
mini_batch_input[k] = dataset_x[shuffle[j * batch_size + k]]
1818
mini_batch_targets[k] = dataset_y[shuffle[j * batch_size + k]]
1919

20-
if mode is None:
20+
if mode == 'train':
2121
model.zero_grad_parameters()
2222
model.accumulate_gradients(mini_batch_input, mini_batch_targets, criterion)
2323
optimiser.update_parameters(model)
24-
else:
24+
elif mode == 'stats':
2525
model.accumulate_statistics(mini_batch_input)
26+
else:
27+
assert False, "Mode should be either 'train' or 'stats'"
2628

2729
progress.update(j * batch_size)
2830

0 commit comments

Comments
 (0)