Skip to content

Commit 7293c82

Browse files
committed
Merge branch 'feature/clean_mnist_v2' into feature/tester
2 parents cb9d156 + 559efcd commit 7293c82

File tree

17 files changed

+315
-126
lines changed

17 files changed

+315
-126
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ before_install:
5757
- if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
5858
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
5959
# protobuf version.
60-
- pip install numpy wheel 'protobuf==3.1' sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker 'scikit-learn>=0.18.0' 'scipy>=0.18.0'
60+
- pip install numpy wheel 'protobuf==3.1' sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker
6161
script:
6262
- paddle/scripts/travis/main.sh
6363
notifications:

demo/mnist/api_train_v2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy
21
import paddle.v2 as paddle
32

43

@@ -41,7 +40,7 @@ def event_handler(event):
4140
trainer.train(
4241
reader=paddle.reader.batched(
4342
paddle.reader.shuffle(
44-
paddle.dataset.mnist.train_creator(), buf_size=8192),
43+
paddle.dataset.mnist.train(), buf_size=8192),
4544
batch_size=32),
4645
event_handler=event_handler)
4746

python/paddle/v2/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import trainer
1919
import event
2020
import data_type
21+
import topology
2122
import data_feeder
2223
from . import dataset
2324
from . import reader
@@ -27,7 +28,8 @@
2728

2829
__all__ = [
2930
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
30-
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader'
31+
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader',
32+
'topology'
3133
]
3234

3335

python/paddle/v2/data_feeder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class DataFeeder(DataProviderConverter):
2323
"""
2424
DataFeeder converts the data returned by paddle.reader into a data structure
2525
of Arguments which is defined in the API. The paddle.reader usually returns
26-
a list of mini-batch data entries. Each data entry in the list is one sampe.
26+
a list of mini-batch data entries. Each data entry in the list is one sample.
2727
Each sample is a list or a tuple with one feature or multiple features.
2828
DataFeeder converts this mini-batch data entries into Arguments in order
2929
to feed it to C++ interface.

python/paddle/v2/data_type.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
# limitations under the License.
1414

1515
from paddle.trainer.PyDataProvider2 import \
16-
InputType, dense_vector, sparse_binary_vector,\
16+
InputType, DataType, dense_vector, sparse_binary_vector,\
1717
sparse_vector, integer_value, integer_value_sequence
1818

1919
__all__ = [
20-
'InputType', 'dense_vector', 'sparse_binary_vector', 'sparse_vector',
21-
'integer_value', 'integer_value_sequence'
20+
'InputType', 'DataType', 'dense_vector', 'sparse_binary_vector',
21+
'sparse_vector', 'integer_value', 'integer_value_sequence'
2222
]

python/paddle/v2/dataset/cifar.py

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,61 @@
11
"""
2-
CIFAR Dataset.
3-
4-
URL: https://www.cs.toronto.edu/~kriz/cifar.html
5-
6-
the default train_creator, test_creator used for CIFAR-10 dataset.
2+
CIFAR dataset: https://www.cs.toronto.edu/~kriz/cifar.html
73
"""
84
import cPickle
95
import itertools
10-
import tarfile
11-
126
import numpy
7+
import paddle.v2.dataset.common
8+
import tarfile
139

14-
from common import download
15-
16-
__all__ = [
17-
'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
18-
'test_creator'
19-
]
10+
__all__ = ['train100', 'test100', 'train10', 'test10']
2011

21-
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
12+
URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
13+
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
2214
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
23-
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
15+
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
2416
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
2517

2618

27-
def __read_batch__(filename, sub_name):
28-
def reader():
29-
def __read_one_batch_impl__(batch):
30-
data = batch['data']
31-
labels = batch.get('labels', batch.get('fine_labels', None))
32-
assert labels is not None
33-
for sample, label in itertools.izip(data, labels):
34-
yield (sample / 255.0).astype(numpy.float32), int(label)
19+
def reader_creator(filename, sub_name):
20+
def read_batch(batch):
21+
data = batch['data']
22+
labels = batch.get('labels', batch.get('fine_labels', None))
23+
assert labels is not None
24+
for sample, label in itertools.izip(data, labels):
25+
yield (sample / 255.0).astype(numpy.float32), int(label)
3526

27+
def reader():
3628
with tarfile.open(filename, mode='r') as f:
3729
names = (each_item.name for each_item in f
3830
if sub_name in each_item.name)
3931

4032
for name in names:
4133
batch = cPickle.load(f.extractfile(name))
42-
for item in __read_one_batch_impl__(batch):
34+
for item in read_batch(batch):
4335
yield item
4436

4537
return reader
4638

4739

48-
def cifar_100_train_creator():
49-
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
50-
return __read_batch__(fn, 'train')
51-
52-
53-
def cifar_100_test_creator():
54-
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
55-
return __read_batch__(fn, 'test')
56-
57-
58-
def train_creator():
59-
"""
60-
Default train reader creator. Use CIFAR-10 dataset.
61-
"""
62-
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
63-
return __read_batch__(fn, 'data_batch')
40+
def train100():
41+
return reader_creator(
42+
paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
43+
'train')
6444

6545

66-
def test_creator():
67-
"""
68-
Default test reader creator. Use CIFAR-10 dataset.
69-
"""
70-
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
71-
return __read_batch__(fn, 'test_batch')
46+
def test100():
47+
return reader_creator(
48+
paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
49+
'test')
7250

7351

74-
def unittest():
75-
for _ in train_creator()():
76-
pass
77-
for _ in test_creator()():
78-
pass
52+
def train10():
53+
return reader_creator(
54+
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
55+
'data_batch')
7956

8057

81-
if __name__ == '__main__':
82-
unittest()
58+
def test10():
59+
return reader_creator(
60+
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
61+
'test_batch')

python/paddle/v2/dataset/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def download(url, module_name, md5sum):
2727

2828
filename = os.path.join(dirname, url.split('/')[-1])
2929
if not (os.path.exists(filename) and md5file(filename) == md5sum):
30-
# If file doesn't exist or MD5 doesn't match, then download.
3130
r = requests.get(url, stream=True)
3231
with open(filename, 'w') as f:
3332
shutil.copyfileobj(r.raw, f)

python/paddle/v2/dataset/mnist.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
"""
2+
MNIST dataset.
3+
"""
14
import paddle.v2.dataset.common
25
import subprocess
36
import numpy
4-
7+
import platform
58
__all__ = ['train', 'test']
69

710
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
8-
911
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
1012
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
1113
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
@@ -18,12 +20,19 @@
1820

1921
def reader_creator(image_filename, label_filename, buffer_size):
2022
def reader():
23+
if platform.system() == 'Darwin':
24+
zcat_cmd = 'gzcat'
25+
elif platform.system() == 'Linux':
26+
zcat_cmd = 'zcat'
27+
else:
28+
raise NotImplementedError()
29+
2130
# According to http://stackoverflow.com/a/38061619/724872, we
2231
# cannot use standard package gzip here.
23-
m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE)
32+
m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE)
2433
m.stdout.read(16) # skip some magic bytes
2534

26-
l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE)
35+
l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
2736
l.stdout.read(8) # skip some magic bytes
2837

2938
while True:
@@ -40,12 +49,12 @@ def reader():
4049
images = images / 255.0 * 2.0 - 1.0
4150

4251
for i in xrange(buffer_size):
43-
yield images[i, :], labels[i]
52+
yield images[i, :], int(labels[i])
4453

4554
m.terminate()
4655
l.terminate()
4756

48-
return reader()
57+
return reader
4958

5059

5160
def train():
python/paddle/v2/dataset/tests/cifar_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import paddle.v2.dataset.cifar
2+
import unittest
3+
4+
5+
class TestCIFAR(unittest.TestCase):
6+
def check_reader(self, reader):
7+
sum = 0
8+
label = 0
9+
for l in reader():
10+
self.assertEqual(l[0].size, 3072)
11+
if l[1] > label:
12+
label = l[1]
13+
sum += 1
14+
return sum, label
15+
16+
def test_test10(self):
17+
instances, max_label_value = self.check_reader(
18+
paddle.v2.dataset.cifar.test10())
19+
self.assertEqual(instances, 10000)
20+
self.assertEqual(max_label_value, 9)
21+
22+
def test_train10(self):
23+
instances, max_label_value = self.check_reader(
24+
paddle.v2.dataset.cifar.train10())
25+
self.assertEqual(instances, 50000)
26+
self.assertEqual(max_label_value, 9)
27+
28+
def test_test100(self):
29+
instances, max_label_value = self.check_reader(
30+
paddle.v2.dataset.cifar.test100())
31+
self.assertEqual(instances, 10000)
32+
self.assertEqual(max_label_value, 99)
33+
34+
def test_train100(self):
35+
instances, max_label_value = self.check_reader(
36+
paddle.v2.dataset.cifar.train100())
37+
self.assertEqual(instances, 50000)
38+
self.assertEqual(max_label_value, 99)
39+
40+
41+
if __name__ == '__main__':
42+
unittest.main()

python/paddle/v2/dataset/tests/mnist_test.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,25 @@
55
class TestMNIST(unittest.TestCase):
66
def check_reader(self, reader):
77
sum = 0
8-
for l in reader:
8+
label = 0
9+
for l in reader():
910
self.assertEqual(l[0].size, 784)
10-
self.assertEqual(l[1].size, 1)
11-
self.assertLess(l[1], 10)
12-
self.assertGreaterEqual(l[1], 0)
11+
if l[1] > label:
12+
label = l[1]
1313
sum += 1
14-
return sum
14+
return sum, label
1515

1616
def test_train(self):
17-
self.assertEqual(
18-
self.check_reader(paddle.v2.dataset.mnist.train()), 60000)
17+
instances, max_label_value = self.check_reader(
18+
paddle.v2.dataset.mnist.train())
19+
self.assertEqual(instances, 60000)
20+
self.assertEqual(max_label_value, 9)
1921

2022
def test_test(self):
21-
self.assertEqual(
22-
self.check_reader(paddle.v2.dataset.mnist.test()), 10000)
23+
instances, max_label_value = self.check_reader(
24+
paddle.v2.dataset.mnist.test())
25+
self.assertEqual(instances, 10000)
26+
self.assertEqual(max_label_value, 9)
2327

2428

2529
if __name__ == '__main__':

0 commit comments

Comments
 (0)