|
1 | 1 | import os |
2 | 2 | import gzip |
3 | 3 | import pickle |
4 | | -import urllib |
5 | 4 | import sys |
| 5 | + |
| 6 | +# Python 2/3 compatibility. |
| 7 | +try: |
| 8 | + from urllib.request import urlretrieve |
| 9 | +except ImportError: |
| 10 | + from urllib import urlretrieve |
| 11 | + |
| 12 | + |
6 | 13 | '''Adapted from theano tutorial''' |
7 | 14 |
|
8 | 15 |
|
9 | | -def load_mnist(data_file = './mnist.pkl.gz'): |
| 16 | +def load_mnist(data_file = os.path.join(os.path.dirname(__file__), 'mnist.pkl.gz')): |
10 | 17 |
|
11 | 18 | if not os.path.exists(data_file): |
12 | | - origin = ('http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz') |
13 | | - print('Downloading data from %s' % origin) |
14 | | - urllib.urlretrieve(origin, data_file) |
| 19 | + origin = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' |
| 20 | + print('Downloading data from {}'.format(origin)) |
| 21 | + urlretrieve(origin, data_file) |
15 | 22 |
|
16 | 23 | print('... loading data') |
17 | 24 |
|
18 | | - f = gzip.open(data_file, 'rb') |
19 | | - if sys.version_info[0] == 3: |
20 | | - train_set, valid_set, test_set = pickle.load(f, encoding='latin1') |
21 | | - else: |
22 | | - train_set, valid_set, test_set = pickle.load(f) |
23 | | - f.close() |
24 | | - |
25 | | - train_set_x, train_set_y = train_set |
26 | | - valid_set_x, valid_set_y = valid_set |
27 | | - test_set_x, test_set_y = test_set |
28 | | - |
29 | | - return (train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y) |
| 25 | + with gzip.open(data_file, 'rb') as f: |
| 26 | + if sys.version_info[0] == 3: |
| 27 | + return pickle.load(f, encoding='latin1') |
| 28 | + else: |
| 29 | + return pickle.load(f) |
0 commit comments