Skip to content

Commit 59f7778

Browse files
authored
Merge pull request #1476 from wangkuiyi/dataset
Simplify CIFAR/MNIST Data Package, Remove Scipy/sklearn package dependencies.
2 parents c444708 + 4eb54c2 commit 59f7778

File tree

8 files changed

+211
-112
lines changed

8 files changed

+211
-112
lines changed

python/paddle/v2/dataset/cifar.py

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,61 @@
11
"""
2-
CIFAR Dataset.
3-
4-
URL: https://www.cs.toronto.edu/~kriz/cifar.html
5-
6-
the default train_creator, test_creator used for CIFAR-10 dataset.
2+
CIFAR dataset: https://www.cs.toronto.edu/~kriz/cifar.html
73
"""
84
import cPickle
95
import itertools
10-
import tarfile
11-
126
import numpy
7+
import paddle.v2.dataset.common
8+
import tarfile
139

14-
from config import download
15-
16-
__all__ = [
17-
'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
18-
'test_creator'
19-
]
10+
__all__ = ['train100', 'test100', 'train10', 'test10']
2011

21-
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
12+
URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
13+
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
2214
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
23-
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
15+
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
2416
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
2517

2618

27-
def __read_batch__(filename, sub_name):
28-
def reader():
29-
def __read_one_batch_impl__(batch):
30-
data = batch['data']
31-
labels = batch.get('labels', batch.get('fine_labels', None))
32-
assert labels is not None
33-
for sample, label in itertools.izip(data, labels):
34-
yield (sample / 255.0).astype(numpy.float32), int(label)
19+
def reader_creator(filename, sub_name):
20+
def read_batch(batch):
21+
data = batch['data']
22+
labels = batch.get('labels', batch.get('fine_labels', None))
23+
assert labels is not None
24+
for sample, label in itertools.izip(data, labels):
25+
yield (sample / 255.0).astype(numpy.float32), int(label)
3526

27+
def reader():
3628
with tarfile.open(filename, mode='r') as f:
3729
names = (each_item.name for each_item in f
3830
if sub_name in each_item.name)
3931

4032
for name in names:
4133
batch = cPickle.load(f.extractfile(name))
42-
for item in __read_one_batch_impl__(batch):
34+
for item in read_batch(batch):
4335
yield item
4436

4537
return reader
4638

4739

48-
def cifar_100_train_creator():
49-
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
50-
return __read_batch__(fn, 'train')
51-
52-
53-
def cifar_100_test_creator():
54-
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
55-
return __read_batch__(fn, 'test')
56-
57-
58-
def train_creator():
59-
"""
60-
Default train reader creator. Use CIFAR-10 dataset.
61-
"""
62-
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
63-
return __read_batch__(fn, 'data_batch')
40+
def train100():
    """Create a reader over the CIFAR-100 training data.

    The archive at ``CIFAR100_URL`` is obtained through
    ``paddle.v2.dataset.common.download`` (module directory 'cifar',
    checked against ``CIFAR100_MD5``); archive members whose name contains
    'train' are read.
    """
    archive_path = paddle.v2.dataset.common.download(
        CIFAR100_URL, 'cifar', CIFAR100_MD5)
    return reader_creator(archive_path, 'train')
6444

6545

66-
def test_creator():
67-
"""
68-
Default test reader creator. Use CIFAR-10 dataset.
69-
"""
70-
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
71-
return __read_batch__(fn, 'test_batch')
46+
def test100():
    """Create a reader over the CIFAR-100 test data.

    The archive at ``CIFAR100_URL`` is obtained through
    ``paddle.v2.dataset.common.download`` (module directory 'cifar',
    checked against ``CIFAR100_MD5``); archive members whose name contains
    'test' are read.
    """
    archive_path = paddle.v2.dataset.common.download(
        CIFAR100_URL, 'cifar', CIFAR100_MD5)
    return reader_creator(archive_path, 'test')
7250

7351

74-
def unittest():
75-
for _ in train_creator()():
76-
pass
77-
for _ in test_creator()():
78-
pass
52+
def train10():
    """Create a reader over the CIFAR-10 training data.

    The archive at ``CIFAR10_URL`` is obtained through
    ``paddle.v2.dataset.common.download`` (module directory 'cifar',
    checked against ``CIFAR10_MD5``); archive members whose name contains
    'data_batch' are read.
    """
    archive_path = paddle.v2.dataset.common.download(
        CIFAR10_URL, 'cifar', CIFAR10_MD5)
    return reader_creator(archive_path, 'data_batch')
7956

8057

81-
if __name__ == '__main__':
82-
unittest()
58+
def test10():
    """Create a reader over the CIFAR-10 test data.

    The archive at ``CIFAR10_URL`` is obtained through
    ``paddle.v2.dataset.common.download`` (module directory 'cifar',
    checked against ``CIFAR10_MD5``); archive members whose name contains
    'test_batch' are read.
    """
    archive_path = paddle.v2.dataset.common.download(
        CIFAR10_URL, 'cifar', CIFAR10_MD5)
    return reader_creator(archive_path, 'test_batch')

python/paddle/v2/dataset/common.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import requests
2+
import hashlib
3+
import os
4+
import shutil
5+
6+
__all__ = ['DATA_HOME', 'download', 'md5file']
7+
8+
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
9+
10+
if not os.path.exists(DATA_HOME):
11+
os.makedirs(DATA_HOME)
12+
13+
14+
def md5file(fname):
    """Return the hexadecimal MD5 digest of the file at *fname*.

    The file is opened in binary mode and consumed in 4096-byte chunks, so
    it is never loaded into memory all at once.

    :param fname: path of the file to hash.
    :return: the digest as a string of hexadecimal digits.
    """
    hash_md5 = hashlib.md5()
    # The context manager closes the handle even when a read fails part-way,
    # which the previous explicit open()/close() pair did not guarantee.
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
21+
22+
23+
def download(url, module_name, md5sum):
    """Fetch *url* into the dataset cache and return the local file path.

    The file is stored as ``DATA_HOME/<module_name>/<last path segment of
    url>``. If a file already exists at that path and its MD5 digest equals
    *md5sum*, nothing is downloaded. The digest is only consulted to decide
    whether to download; a freshly downloaded file is not re-checked.

    :param url: address of the file to fetch.
    :param module_name: sub-directory of ``DATA_HOME`` to store the file in.
    :param md5sum: expected hexadecimal MD5 digest of the cached file.
    :return: path of the cached file.
    """
    dirname = os.path.join(DATA_HOME, module_name)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    filename = os.path.join(dirname, url.split('/')[-1])
    if not (os.path.exists(filename) and md5file(filename) == md5sum):
        r = requests.get(url, stream=True)
        try:
            # The payload is raw bytes (archives, images). A text-mode
            # handle may translate newlines on some platforms and corrupt
            # it, so write in binary mode.
            with open(filename, 'wb') as f:
                shutil.copyfileobj(r.raw, f)
        finally:
            # Release the streamed connection whether or not the copy worked.
            r.close()

    return filename

python/paddle/v2/dataset/config.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

python/paddle/v2/dataset/mnist.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,66 @@
1-
import sklearn.datasets.mldata
2-
import sklearn.model_selection
1+
"""
2+
MNIST dataset.
3+
"""
34
import numpy
4-
from config import DATA_HOME
5+
import paddle.v2.dataset.common
6+
import subprocess
57

6-
__all__ = ['train_creator', 'test_creator']
8+
__all__ = ['train', 'test']
79

10+
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
11+
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
12+
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
13+
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
14+
TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688'
15+
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
16+
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
17+
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
18+
TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
819

9-
def __mnist_reader_creator__(data, target):
20+
21+
def reader_creator(image_filename, label_filename, buffer_size):
1022
def reader():
11-
n_samples = data.shape[0]
12-
for i in xrange(n_samples):
13-
yield (data[i] / 255.0).astype(numpy.float32), int(target[i])
23+
# According to http://stackoverflow.com/a/38061619/724872, we
24+
# cannot use standard package gzip here.
25+
m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE)
26+
m.stdout.read(16) # skip some magic bytes
1427

15-
return reader
28+
l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE)
29+
l.stdout.read(8) # skip some magic bytes
1630

31+
while True:
32+
labels = numpy.fromfile(
33+
l.stdout, 'ubyte', count=buffer_size).astype("int")
1734

18-
TEST_SIZE = 10000
35+
if labels.size != buffer_size:
36+
break # numpy.fromfile returns empty slice after EOF.
1937

20-
data = sklearn.datasets.mldata.fetch_mldata(
21-
"MNIST original", data_home=DATA_HOME)
22-
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
23-
data.data, data.target, test_size=TEST_SIZE, random_state=0)
38+
images = numpy.fromfile(
39+
m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
40+
(buffer_size, 28 * 28)).astype('float32')
2441

42+
images = images / 255.0 * 2.0 - 1.0
2543

26-
def train_creator():
27-
return __mnist_reader_creator__(X_train, y_train)
44+
for i in xrange(buffer_size):
45+
yield images[i, :], int(labels[i])
2846

47+
m.terminate()
48+
l.terminate()
2949

30-
def test_creator():
31-
return __mnist_reader_creator__(X_test, y_test)
50+
return reader
3251

3352

34-
def unittest():
35-
assert len(list(test_creator()())) == TEST_SIZE
53+
def train():
    """Create a reader over the MNIST training images and labels.

    Both gzip files are obtained through
    ``paddle.v2.dataset.common.download`` into the 'mnist' module directory
    (images first, then labels) and handed to ``reader_creator`` with a
    buffer size of 100.
    """
    image_path = paddle.v2.dataset.common.download(
        TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
    label_path = paddle.v2.dataset.common.download(
        TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
    return reader_creator(image_path, label_path, 100)
3659

3760

38-
if __name__ == '__main__':
39-
unittest()
61+
def test():
    """Create a reader over the MNIST test images and labels.

    Both gzip files are obtained through
    ``paddle.v2.dataset.common.download`` into the 'mnist' module directory
    (images first, then labels) and handed to ``reader_creator`` with a
    buffer size of 100.
    """
    image_path = paddle.v2.dataset.common.download(
        TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
    label_path = paddle.v2.dataset.common.download(
        TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5)
    return reader_creator(image_path, label_path, 100)

python/paddle/v2/dataset/movielens.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import zipfile
2-
from config import download
2+
from common import download
33
import re
44
import random
55
import functools
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import paddle.v2.dataset.cifar
2+
import unittest
3+
4+
5+
class TestCIFAR(unittest.TestCase):
    """Checks the four readers exposed by paddle.v2.dataset.cifar."""

    def check_reader(self, reader):
        """Drain *reader* and return (sample count, largest label seen).

        Each yielded item is indexed as (features, label); the features
        must hold exactly 3072 values.
        """
        count = 0
        max_label = 0
        for sample in reader():
            self.assertEqual(sample[0].size, 3072)
            max_label = max(max_label, sample[1])
            count += 1
        return count, max_label

    def test_test10(self):
        reader = paddle.v2.dataset.cifar.test10()
        instances, max_label_value = self.check_reader(reader)
        self.assertEqual(instances, 10000)
        self.assertEqual(max_label_value, 9)

    def test_train10(self):
        reader = paddle.v2.dataset.cifar.train10()
        instances, max_label_value = self.check_reader(reader)
        self.assertEqual(instances, 50000)
        self.assertEqual(max_label_value, 9)

    def test_test100(self):
        reader = paddle.v2.dataset.cifar.test100()
        instances, max_label_value = self.check_reader(reader)
        self.assertEqual(instances, 10000)
        self.assertEqual(max_label_value, 99)

    def test_train100(self):
        reader = paddle.v2.dataset.cifar.train100()
        instances, max_label_value = self.check_reader(reader)
        self.assertEqual(instances, 50000)
        self.assertEqual(max_label_value, 99)
39+
40+
41+
if __name__ == '__main__':
42+
unittest.main()
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import paddle.v2.dataset.common
2+
import unittest
3+
import tempfile
4+
5+
6+
class TestCommon(unittest.TestCase):
    """Checks the helpers exposed by paddle.v2.dataset.common."""

    def test_md5file(self):
        """md5file returns the expected digest for a small known file."""
        import os  # local import: nothing else in this module needs it

        # mkstemp() hands back an already-open OS-level descriptor and never
        # deletes the file. Wrap the descriptor so it is closed rather than
        # leaked, and remove the file once the check is done.
        fd, temp_path = tempfile.mkstemp()
        try:
            with os.fdopen(fd, 'w') as f:
                f.write("Hello\n")
            self.assertEqual('09f7e02f1290be211da707a266f153b3',
                             paddle.v2.dataset.common.md5file(temp_path))
        finally:
            os.remove(temp_path)

    def test_download(self):
        """download returns the cache path derived from the URL's last segment."""
        yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460'
        self.assertEqual(
            paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460',
            paddle.v2.dataset.common.download(
                yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))
20+
21+
22+
if __name__ == '__main__':
23+
unittest.main()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import paddle.v2.dataset.mnist
2+
import unittest
3+
4+
5+
class TestMNIST(unittest.TestCase):
    """Checks the two readers exposed by paddle.v2.dataset.mnist."""

    def check_reader(self, reader):
        """Drain *reader* and return (sample count, largest label seen).

        Each yielded item is indexed as (features, label); the features
        must hold exactly 784 values.
        """
        count = 0
        max_label = 0
        for sample in reader():
            self.assertEqual(sample[0].size, 784)
            max_label = max(max_label, sample[1])
            count += 1
        return count, max_label

    def test_train(self):
        reader = paddle.v2.dataset.mnist.train()
        instances, max_label_value = self.check_reader(reader)
        self.assertEqual(instances, 60000)
        self.assertEqual(max_label_value, 9)

    def test_test(self):
        reader = paddle.v2.dataset.mnist.test()
        instances, max_label_value = self.check_reader(reader)
        self.assertEqual(instances, 10000)
        self.assertEqual(max_label_value, 9)
27+
28+
29+
if __name__ == '__main__':
30+
unittest.main()

0 commit comments

Comments
 (0)