Skip to content

Commit 559efcd

Browse files
committed
Merge branch 'develop' of github.com:baidu/Paddle into feature/clean_mnist_v2
2 parents faa43e3 + 59f7778 commit 559efcd

File tree

5 files changed

+91
-66
lines changed

5 files changed

+91
-66
lines changed

python/paddle/v2/dataset/cifar.py

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,61 @@
11
"""
2-
CIFAR Dataset.
3-
4-
URL: https://www.cs.toronto.edu/~kriz/cifar.html
5-
6-
the default train_creator, test_creator used for CIFAR-10 dataset.
2+
CIFAR dataset: https://www.cs.toronto.edu/~kriz/cifar.html
73
"""
84
import cPickle
95
import itertools
10-
import tarfile
11-
126
import numpy
7+
import paddle.v2.dataset.common
8+
import tarfile
139

14-
from common import download
15-
16-
__all__ = [
17-
'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
18-
'test_creator'
19-
]
10+
__all__ = ['train100', 'test100', 'train10', 'test10']
2011

21-
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
12+
URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
13+
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
2214
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
23-
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
15+
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
2416
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
2517

2618

def reader_creator(filename, sub_name):
    """Create a reader over the batches stored inside a CIFAR tarball.

    :param filename: path to the downloaded CIFAR ``.tar.gz`` archive.
    :param sub_name: substring selecting which archive members to read,
        e.g. 'train', 'test', 'data_batch' or 'test_batch'.
    :return: a zero-argument reader function; each call yields
        ``(image, label)`` pairs where ``image`` is the raw pixel row
        divided by 255.0 and cast to ``numpy.float32`` and ``label`` is
        an ``int``.
    """
    # Local import keeps the helper usable on both Python 2 (cPickle)
    # and Python 3 (pickle) without changing the module's import block.
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    def read_batch(batch):
        # A pickled batch dict stores pixel rows under 'data' and labels
        # under 'labels' (CIFAR-10) or 'fine_labels' (CIFAR-100).
        data = batch['data']
        labels = batch.get('labels', batch.get('fine_labels', None))
        if labels is None:
            # Explicit error instead of `assert`: asserts are stripped
            # when Python runs with -O.
            raise ValueError(
                'batch has neither "labels" nor "fine_labels"')
        for sample, label in zip(data, labels):
            yield (sample / 255.0).astype(numpy.float32), int(label)

    def reader():
        with tarfile.open(filename, mode='r') as f:
            names = (each_item.name for each_item in f
                     if sub_name in each_item.name)

            for name in names:
                batch = pickle.load(f.extractfile(name))
                for item in read_batch(batch):
                    yield item

    return reader
4638

4739

48-
def cifar_100_train_creator():
49-
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
50-
return __read_batch__(fn, 'train')
51-
52-
53-
def cifar_100_test_creator():
54-
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
55-
return __read_batch__(fn, 'test')
56-
57-
58-
def train_creator():
59-
"""
60-
Default train reader creator. Use CIFAR-10 dataset.
61-
"""
62-
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
63-
return __read_batch__(fn, 'data_batch')
def train100():
    """
    Reader creator for the CIFAR-100 training set.

    The reader yields (image, label) pairs; images are float32 arrays
    with pixel values divided by 255.0.
    """
    path = paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar',
                                             CIFAR100_MD5)
    return reader_creator(path, 'train')
6444

6545

66-
def test_creator():
67-
"""
68-
Default test reader creator. Use CIFAR-10 dataset.
69-
"""
70-
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
71-
return __read_batch__(fn, 'test_batch')
def test100():
    """
    Reader creator for the CIFAR-100 test set.

    The reader yields (image, label) pairs; images are float32 arrays
    with pixel values divided by 255.0.
    """
    path = paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar',
                                             CIFAR100_MD5)
    return reader_creator(path, 'test')
7250

7351

74-
def unittest():
75-
for _ in train_creator()():
76-
pass
77-
for _ in test_creator()():
78-
pass
def train10():
    """
    Reader creator for the CIFAR-10 training set.

    The reader yields (image, label) pairs; images are float32 arrays
    with pixel values divided by 255.0.
    """
    path = paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar',
                                             CIFAR10_MD5)
    return reader_creator(path, 'data_batch')
7956

8057

81-
if __name__ == '__main__':
82-
unittest()
def test10():
    """
    Reader creator for the CIFAR-10 test set.

    The reader yields (image, label) pairs; images are float32 arrays
    with pixel values divided by 255.0.
    """
    path = paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar',
                                             CIFAR10_MD5)
    return reader_creator(path, 'test_batch')

python/paddle/v2/dataset/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def download(url, module_name, md5sum):
2727

2828
filename = os.path.join(dirname, url.split('/')[-1])
2929
if not (os.path.exists(filename) and md5file(filename) == md5sum):
30-
# If file doesn't exist or MD5 doesn't match, then download.
3130
r = requests.get(url, stream=True)
3231
with open(filename, 'w') as f:
3332
shutil.copyfileobj(r.raw, f)

python/paddle/v2/dataset/mnist.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
"""
2+
MNIST dataset.
3+
"""
14
import paddle.v2.dataset.common
25
import subprocess
36
import numpy
47
import platform
5-
68
__all__ = ['train', 'test']
79

810
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
9-
1011
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
1112
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
1213
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
@@ -48,7 +49,7 @@ def reader():
4849
images = images / 255.0 * 2.0 - 1.0
4950

5051
for i in xrange(buffer_size):
51-
yield images[i, :], labels[i]
52+
yield images[i, :], int(labels[i])
5253

5354
m.terminate()
5455
l.terminate()
python/paddle/v2/dataset/tests/cifar_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import paddle.v2.dataset.cifar
2+
import unittest
3+
4+
5+
class TestCIFAR(unittest.TestCase):
6+
def check_reader(self, reader):
7+
sum = 0
8+
label = 0
9+
for l in reader():
10+
self.assertEqual(l[0].size, 3072)
11+
if l[1] > label:
12+
label = l[1]
13+
sum += 1
14+
return sum, label
15+
16+
def test_test10(self):
17+
instances, max_label_value = self.check_reader(
18+
paddle.v2.dataset.cifar.test10())
19+
self.assertEqual(instances, 10000)
20+
self.assertEqual(max_label_value, 9)
21+
22+
def test_train10(self):
23+
instances, max_label_value = self.check_reader(
24+
paddle.v2.dataset.cifar.train10())
25+
self.assertEqual(instances, 50000)
26+
self.assertEqual(max_label_value, 9)
27+
28+
def test_test100(self):
29+
instances, max_label_value = self.check_reader(
30+
paddle.v2.dataset.cifar.test100())
31+
self.assertEqual(instances, 10000)
32+
self.assertEqual(max_label_value, 99)
33+
34+
def test_train100(self):
35+
instances, max_label_value = self.check_reader(
36+
paddle.v2.dataset.cifar.train100())
37+
self.assertEqual(instances, 50000)
38+
self.assertEqual(max_label_value, 99)
39+
40+
41+
if __name__ == '__main__':
42+
unittest.main()

python/paddle/v2/dataset/tests/mnist_test.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,25 @@
55
class TestMNIST(unittest.TestCase):
66
def check_reader(self, reader):
77
sum = 0
8-
for l in reader:
8+
label = 0
9+
for l in reader():
910
self.assertEqual(l[0].size, 784)
10-
self.assertEqual(l[1].size, 1)
11-
self.assertLess(l[1], 10)
12-
self.assertGreaterEqual(l[1], 0)
11+
if l[1] > label:
12+
label = l[1]
1313
sum += 1
14-
return sum
14+
return sum, label
1515

1616
def test_train(self):
17-
self.assertEqual(
18-
self.check_reader(paddle.v2.dataset.mnist.train()), 60000)
17+
instances, max_label_value = self.check_reader(
18+
paddle.v2.dataset.mnist.train())
19+
self.assertEqual(instances, 60000)
20+
self.assertEqual(max_label_value, 9)
1921

2022
def test_test(self):
21-
self.assertEqual(
22-
self.check_reader(paddle.v2.dataset.mnist.test()), 10000)
23+
instances, max_label_value = self.check_reader(
24+
paddle.v2.dataset.mnist.test())
25+
self.assertEqual(instances, 10000)
26+
self.assertEqual(max_label_value, 9)
2327

2428

2529
if __name__ == '__main__':

0 commit comments

Comments
 (0)