-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmnist_web.py
More file actions
68 lines (54 loc) · 2.48 KB
/
mnist_web.py
File metadata and controls
68 lines (54 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gzip
import os
#from urllib.request import urlretrieve
import numpy as np
def mnist(path=None):
    r"""Return (train_images, train_labels, test_images, test_labels).

    Args:
        path (str): Directory containing MNIST. Default is
            /home/USER/data/mnist or C:\Users\USER\data\mnist.
            Created if nonexistent. Any missing files are downloaded.

    Returns:
        Tuple of (train_images, train_labels, test_images, test_labels), each
        a matrix. Rows are examples. Columns of images are the 784 pixel
        values as float32 scaled to [0, 1]. Columns of labels are a boolean
        onehot encoding of the correct class; the number of columns is the
        largest label present in that file plus one.

    Raises:
        Any error raised by ``urlretrieve`` propagates unchanged when a
        missing file cannot be downloaded.
    """
    url = 'http://yann.lecun.com/exdb/mnist/'
    files = ['train-images-idx3-ubyte.gz',
             'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz',
             't10k-labels-idx1-ubyte.gz']

    if path is None:
        # Set path to /home/USER/data/mnist or C:\Users\USER\data\mnist
        path = os.path.join(os.path.expanduser('~'), 'data', 'mnist')

    # Create path if it doesn't exist, so the checks below cannot fail on a
    # missing directory.
    os.makedirs(path, exist_ok=True)

    # Download any missing files.
    missing = [name for name in files
               if not os.path.isfile(os.path.join(path, name))]
    if missing:
        # Imported lazily: the network stack is only needed when something
        # actually has to be fetched.
        from urllib.request import urlretrieve

        for name in missing:
            urlretrieve(url + name, os.path.join(path, name))
            print("Downloaded %s to %s" % (name, path))

    def _images(path):
        """Return flattened images loaded from local file."""
        with gzip.open(path) as f:
            # First 16 bytes are magic_number, n_imgs, n_rows, n_cols
            pixels = np.frombuffer(f.read(), '>B', offset=16)
        return pixels.reshape(-1, 784).astype('float32') / 255

    def _labels(path):
        """Return onehot labels loaded from local file."""
        with gzip.open(path) as f:
            # First 8 bytes are magic_number, n_labels
            integer_labels = np.frombuffer(f.read(), '>B', offset=8)

        def _onehot(integer_labels):
            """Return matrix whose rows are onehot encodings of integers."""
            n_rows = len(integer_labels)
            # Cast to a Python int first so the "+ 1" cannot wrap around in
            # uint8 arithmetic.
            n_cols = int(integer_labels.max()) + 1
            onehot = np.zeros((n_rows, n_cols), dtype='bool')
            onehot[np.arange(n_rows), integer_labels] = 1
            return onehot

        return _onehot(integer_labels)

    train_images = _images(os.path.join(path, files[0]))
    train_labels = _labels(os.path.join(path, files[1]))
    test_images = _images(os.path.join(path, files[2]))
    test_labels = _labels(os.path.join(path, files[3]))

    return train_images, train_labels, test_images, test_labels