Commit cb9d156

Merge branch 'feature/clean_mnist_v2' into feature/tester
2 parents db8566d + e44f053

File tree

7 files changed: +333 -30 lines

python/paddle/v2/dataset/cifar.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
"""
CIFAR Dataset.

URL: https://www.cs.toronto.edu/~kriz/cifar.html

The default train_creator and test_creator use the CIFAR-10 dataset.
"""
import cPickle
import itertools
import tarfile

import numpy

from common import download

__all__ = [
    'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
    'test_creator'
]

CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'


def __read_batch__(filename, sub_name):
    def reader():
        def __read_one_batch_impl__(batch):
            data = batch['data']
            # CIFAR-10 batches store 'labels'; CIFAR-100 stores 'fine_labels'.
            labels = batch.get('labels', batch.get('fine_labels', None))
            assert labels is not None
            for sample, label in itertools.izip(data, labels):
                yield (sample / 255.0).astype(numpy.float32), int(label)

        with tarfile.open(filename, mode='r') as f:
            names = (each_item.name for each_item in f
                     if sub_name in each_item.name)

            for name in names:
                batch = cPickle.load(f.extractfile(name))
                for item in __read_one_batch_impl__(batch):
                    yield item

    return reader


def cifar_100_train_creator():
    # The 'cifar' cache sub-directory is assumed; common.download takes
    # (url, module_name, md5sum).
    fn = download(url=CIFAR100_URL, module_name='cifar', md5sum=CIFAR100_MD5)
    return __read_batch__(fn, 'train')


def cifar_100_test_creator():
    fn = download(url=CIFAR100_URL, module_name='cifar', md5sum=CIFAR100_MD5)
    return __read_batch__(fn, 'test')


def train_creator():
    """
    Default train reader creator. Uses the CIFAR-10 dataset.
    """
    fn = download(url=CIFAR10_URL, module_name='cifar', md5sum=CIFAR10_MD5)
    return __read_batch__(fn, 'data_batch')


def test_creator():
    """
    Default test reader creator. Uses the CIFAR-10 dataset.
    """
    fn = download(url=CIFAR10_URL, module_name='cifar', md5sum=CIFAR10_MD5)
    return __read_batch__(fn, 'test_batch')


def unittest():
    for _ in train_creator()():
        pass
    for _ in test_creator()():
        pass


if __name__ == '__main__':
    unittest()
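Each creator returns a reader function rather than the data itself; calling that reader yields (image, label) tuples with pixels scaled to [0, 1]. A minimal consumption sketch (the paddle.v2.dataset.cifar import path is assumed from this file's location):

import paddle.v2.dataset.cifar as cifar

reader = cifar.train_creator()        # a callable, not the data itself
for image, label in reader():         # one pass over the CIFAR-10 train batches
    assert image.size == 3 * 32 * 32  # flattened 3x32x32 float32 vector
    assert 0 <= label < 10            # ten CIFAR-10 classes
    break                             # inspect only the first sample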

python/paddle/v2/dataset/common.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import requests
import hashlib
import os
import shutil

__all__ = ['DATA_HOME', 'download', 'md5file']

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')

if not os.path.exists(DATA_HOME):
    os.makedirs(DATA_HOME)


def md5file(fname):
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        # Hash in 4 KB chunks so large files never load into memory at once.
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def download(url, module_name, md5sum):
    dirname = os.path.join(DATA_HOME, module_name)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    filename = os.path.join(dirname, url.split('/')[-1])
    if not (os.path.exists(filename) and md5file(filename) == md5sum):
        # If the file doesn't exist or its MD5 doesn't match, download it.
        r = requests.get(url, stream=True)
        # Binary mode: text mode would corrupt gzip/zip payloads.
        with open(filename, 'wb') as f:
            shutil.copyfileobj(r.raw, f)

    return filename
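download caches each file under DATA_HOME/<module_name>/ and skips the network fetch whenever the cached copy already passes the MD5 check. A small usage sketch, reusing the CIFAR-10 URL and checksum defined in cifar.py above:

from common import download, md5file

url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
md5 = 'c58f30108f718f92721af3b95e74349a'

path = download(url, 'cifar', md5)          # fetches on the first call
assert md5file(path) == md5
assert download(url, 'cifar', md5) == path  # cache hit: no second fetch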

python/paddle/v2/dataset/config.py

Lines changed: 0 additions & 8 deletions
This file was deleted.
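As the mnist.py diff below shows (it drops `from config import DATA_HOME`), config.py previously supplied DATA_HOME; that role now lives in common.py above.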

python/paddle/v2/dataset/mnist.py

Lines changed: 47 additions & 22 deletions
@@ -1,39 +1,64 @@
-import sklearn.datasets.mldata
-import sklearn.model_selection
+import paddle.v2.dataset.common
+import subprocess
 import numpy
-from config import DATA_HOME

-__all__ = ['train_creator', 'test_creator']
+__all__ = ['train', 'test']

+URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'

-def __mnist_reader_creator__(data, target):
+TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
+TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
+TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
+TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688'
+TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
+TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
+TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
+TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
+
+
+def reader_creator(image_filename, label_filename, buffer_size):
     def reader():
-        n_samples = data.shape[0]
-        for i in xrange(n_samples):
-            yield (data[i] / 255.0).astype(numpy.float32), int(target[i])
+        # According to http://stackoverflow.com/a/38061619/724872, we
+        # cannot use standard package gzip here.
+        m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE)
+        m.stdout.read(16)  # skip some magic bytes

-    return reader
+        l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE)
+        l.stdout.read(8)  # skip some magic bytes

+        while True:
+            labels = numpy.fromfile(
+                l.stdout, 'ubyte', count=buffer_size).astype("int")

-TEST_SIZE = 10000
+            if labels.size != buffer_size:
+                break  # numpy.fromfile returns empty slice after EOF.

-data = sklearn.datasets.mldata.fetch_mldata(
-    "MNIST original", data_home=DATA_HOME)
-X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
-    data.data, data.target, test_size=TEST_SIZE, random_state=0)
+            images = numpy.fromfile(
+                m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
+                    (buffer_size, 28 * 28)).astype('float32')

+            images = images / 255.0 * 2.0 - 1.0

-def train_creator():
-    return __mnist_reader_creator__(X_train, y_train)
+            for i in xrange(buffer_size):
+                yield images[i, :], labels[i]

+        m.terminate()
+        l.terminate()

-def test_creator():
-    return __mnist_reader_creator__(X_test, y_test)
+    return reader()


-def unittest():
-    assert len(list(test_creator()())) == TEST_SIZE
+def train():
+    return reader_creator(
+        paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist',
+                                          TRAIN_IMAGE_MD5),
+        paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist',
+                                          TRAIN_LABEL_MD5), 100)


-if __name__ == '__main__':
-    unittest()
+def test():
+    return reader_creator(
+        paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist',
+                                          TEST_IMAGE_MD5),
+        paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist',
+                                          TEST_LABEL_MD5), 100)
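The skipped bytes are the IDX headers: the image file starts with four big-endian 32-bit integers (magic number, image count, rows, columns), the label file with two (magic number, label count); everything after is raw unsigned bytes. Since train() and test() return generators directly (note `return reader()`), consuming them is plain iteration. A minimal sketch, assuming the zcat binary is on PATH as the reader requires:

import paddle.v2.dataset.mnist as mnist

for image, label in mnist.test():
    assert image.shape == (28 * 28,)                   # flattened 28x28 image
    assert -1.0 <= image.min() <= image.max() <= 1.0   # rescaled to [-1, 1]
    assert 0 <= label < 10
    break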
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
import zipfile
from common import download
import re
import random
import functools

__all__ = ['train_creator', 'test_creator']


class MovieInfo(object):
    def __init__(self, index, categories, title):
        self.index = int(index)
        self.categories = categories
        self.title = title

    def value(self):
        return [
            self.index, [CATEGORIES_DICT[c] for c in self.categories],
            [MOVIE_TITLE_DICT[w.lower()] for w in self.title.split()]
        ]


class UserInfo(object):
    def __init__(self, index, gender, age, job_id):
        self.index = int(index)
        self.is_male = gender == 'M'
        # Bucket raw ages into the seven MovieLens age groups.
        self.age = [1, 18, 25, 35, 45, 50, 56].index(int(age))
        self.job_id = int(job_id)

    def value(self):
        return [self.index, 0 if self.is_male else 1, self.age, self.job_id]


MOVIE_INFO = None
MOVIE_TITLE_DICT = None
CATEGORIES_DICT = None
USER_INFO = None


def __initialize_meta_info__():
    # The 'movielens' cache sub-directory is assumed; common.download takes
    # (url, module_name, md5sum).
    fn = download(
        url='http://files.grouplens.org/datasets/movielens/ml-1m.zip',
        module_name='movielens',
        md5sum='c4d9eecfca2ab87c1945afe126590906')
    global MOVIE_INFO
    if MOVIE_INFO is None:
        pattern = re.compile(r'^(.*)\((\d+)\)$')
        with zipfile.ZipFile(file=fn) as package:
            for info in package.infolist():
                assert isinstance(info, zipfile.ZipInfo)
            MOVIE_INFO = dict()
            title_word_set = set()
            categories_set = set()
            with package.open('ml-1m/movies.dat') as movie_file:
                for i, line in enumerate(movie_file):
                    movie_id, title, categories = line.strip().split('::')
                    categories = categories.split('|')
                    for c in categories:
                        categories_set.add(c)
                    # Strip the trailing "(year)" from the title.
                    title = pattern.match(title).group(1)
                    MOVIE_INFO[int(movie_id)] = MovieInfo(
                        index=movie_id, categories=categories, title=title)
                    for w in title.split():
                        title_word_set.add(w.lower())

            global MOVIE_TITLE_DICT
            MOVIE_TITLE_DICT = dict()
            for i, w in enumerate(title_word_set):
                MOVIE_TITLE_DICT[w] = i

            global CATEGORIES_DICT
            CATEGORIES_DICT = dict()
            for i, c in enumerate(categories_set):
                CATEGORIES_DICT[c] = i

            global USER_INFO
            USER_INFO = dict()
            with package.open('ml-1m/users.dat') as user_file:
                for line in user_file:
                    uid, gender, age, job, _ = line.strip().split("::")
                    USER_INFO[int(uid)] = UserInfo(
                        index=uid, gender=gender, age=age, job_id=job)
    return fn


def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
    fn = __initialize_meta_info__()
    rand = random.Random(x=rand_seed)
    with zipfile.ZipFile(file=fn) as package:
        with package.open('ml-1m/ratings.dat') as rating_file:
            for line in rating_file:
                # A fixed seed yields the same train/test split on every pass.
                if (rand.random() < test_ratio) == is_test:
                    uid, mov_id, rating, _ = line.strip().split("::")
                    uid = int(uid)
                    mov_id = int(mov_id)
                    rating = float(rating) * 2 - 5.0

                    mov = MOVIE_INFO[mov_id]
                    usr = USER_INFO[uid]
                    yield usr.value() + mov.value() + [[rating]]


def __reader_creator__(**kwargs):
    return lambda: __reader__(**kwargs)


train_creator = functools.partial(__reader_creator__, is_test=False)
test_creator = functools.partial(__reader_creator__, is_test=True)


def unittest():
    for train_count, _ in enumerate(train_creator()()):
        pass
    for test_count, _ in enumerate(test_creator()()):
        pass

    print train_count, test_count


if __name__ == '__main__':
    unittest()
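Each yielded sample concatenates usr.value(), mov.value(), and the rating rescaled from 1-5 stars to [-3.0, 5.0]. A minimal consumption sketch (the paddle.v2.dataset.movielens import path is assumed):

import paddle.v2.dataset.movielens as movielens

sample = next(movielens.train_creator()())
# layout: [user_index, gender (0 = male), age_bucket, job_id,
#          movie_index, [category ids], [title word ids], [rating]]
assert -3.0 <= sample[-1][0] <= 5.0  # rating * 2 - 5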
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
import paddle.v2.dataset.common
import unittest
import tempfile


class TestCommon(unittest.TestCase):
    def test_md5file(self):
        _, temp_path = tempfile.mkstemp()
        with open(temp_path, 'w') as f:
            f.write("Hello\n")
        self.assertEqual('09f7e02f1290be211da707a266f153b3',
                         paddle.v2.dataset.common.md5file(temp_path))

    def test_download(self):
        yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460'
        self.assertEqual(
            paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460',
            paddle.v2.dataset.common.download(
                yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))


if __name__ == '__main__':
    unittest.main()
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
import paddle.v2.dataset.mnist
import unittest


class TestMNIST(unittest.TestCase):
    def check_reader(self, reader):
        sum = 0
        for l in reader:
            self.assertEqual(l[0].size, 784)
            self.assertEqual(l[1].size, 1)
            self.assertLess(l[1], 10)
            self.assertGreaterEqual(l[1], 0)
            sum += 1
        return sum

    def test_train(self):
        self.assertEqual(
            self.check_reader(paddle.v2.dataset.mnist.train()), 60000)

    def test_test(self):
        self.assertEqual(
            self.check_reader(paddle.v2.dataset.mnist.test()), 10000)


if __name__ == '__main__':
    unittest.main()
