Skip to content

Commit 283e82f

Browse files
authored
Merge pull request #1450 from reyoung/feature/cifar_dataset
Feature/cifar dataset
2 parents 2f60406 + 0bcc4d4 commit 283e82f

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

python/paddle/v2/dataset/cifar.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
CIFAR Dataset.
3+
4+
URL: https://www.cs.toronto.edu/~kriz/cifar.html
5+
6+
The default train_creator and test_creator use the CIFAR-10 dataset.
7+
"""
8+
from config import DATA_HOME
9+
import os
10+
import hashlib
11+
import urllib2
12+
import shutil
13+
import tarfile
14+
import cPickle
15+
import itertools
16+
import numpy
17+
18+
__all__ = [
19+
'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
20+
'test_creator'
21+
]
22+
23+
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
24+
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
25+
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
26+
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
27+
28+
29+
def __read_batch__(filename, sub_name):
30+
def reader():
31+
def __read_one_batch_impl__(batch):
32+
data = batch['data']
33+
labels = batch.get('labels', batch.get('fine_labels', None))
34+
assert labels is not None
35+
for sample, label in itertools.izip(data, labels):
36+
yield (sample / 255.0).astype(numpy.float32), int(label)
37+
38+
with tarfile.open(filename, mode='r') as f:
39+
names = (each_item.name for each_item in f
40+
if sub_name in each_item.name)
41+
42+
for name in names:
43+
batch = cPickle.load(f.extractfile(name))
44+
for item in __read_one_batch_impl__(batch):
45+
yield item
46+
47+
return reader
48+
49+
50+
def download(url, md5, retry_limit=3):
    """
    Fetch ``url`` into ``DATA_HOME/<md5>/`` and return the local path.

    The local file name is the last path component of ``url``. A file
    that is already present and whose MD5 digest equals ``md5`` is
    reused without any network access.

    :param url: location of the file to download.
    :param md5: expected hex MD5 digest of the file; it also names the
        cache sub-directory under ``DATA_HOME``.
    :param retry_limit: maximum number of download attempts before
        giving up (previously the function retried forever).
    :return: path of the verified local file.
    :raises RuntimeError: if the local file still fails verification
        after ``retry_limit`` download attempts.
    """
    assert DATA_HOME is not None
    cache_dir = os.path.join(DATA_HOME, md5)
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    target = os.path.join(cache_dir, os.path.split(url)[-1])

    def _file_ok():
        # True only when the file exists and its MD5 digest matches.
        if not os.path.exists(target):
            return False
        digest = hashlib.md5()
        with open(target, 'rb') as f:
            # Hash in small chunks so large archives are never held
            # in memory as a whole.
            for chunk in iter(lambda: f.read(4096), b""):
                digest.update(chunk)
        return digest.hexdigest() == md5

    attempts = 0
    while not _file_ok():
        # Bound the retries: a server that keeps returning corrupt or
        # changed content must not trap the caller in an endless loop.
        if attempts >= retry_limit:
            raise RuntimeError(
                'Could not obtain a copy of %s matching MD5 %s after %d '
                'download attempt(s)' % (url, md5, attempts))
        attempts += 1
        response = urllib2.urlopen(url)
        try:
            with open(target, mode='wb') as of:
                shutil.copyfileobj(fsrc=response, fdst=of)
        finally:
            # Release the connection even if writing to disk fails.
            response.close()
    return target
73+
74+
75+
def cifar_100_train_creator():
    """
    CIFAR-100 train reader creator.

    Obtains the CIFAR-100 archive through ``download`` and returns the
    reader over the archive members whose name contains 'train'.
    """
    archive_path = download(md5=CIFAR100_MD5, url=CIFAR100_URL)
    return __read_batch__(filename=archive_path, sub_name='train')
78+
79+
80+
def cifar_100_test_creator():
    """
    CIFAR-100 test reader creator.

    Obtains the CIFAR-100 archive through ``download`` and returns the
    reader over the archive members whose name contains 'test'.
    """
    archive_path = download(md5=CIFAR100_MD5, url=CIFAR100_URL)
    return __read_batch__(filename=archive_path, sub_name='test')
83+
84+
85+
def train_creator():
    """
    Default train reader creator, backed by the CIFAR-10 dataset.

    Obtains the CIFAR-10 archive through ``download`` and returns the
    reader over the archive members whose name contains 'data_batch'.
    """
    archive_path = download(md5=CIFAR10_MD5, url=CIFAR10_URL)
    return __read_batch__(filename=archive_path, sub_name='data_batch')
91+
92+
93+
def test_creator():
    """
    Default test reader creator, backed by the CIFAR-10 dataset.

    Obtains the CIFAR-10 archive through ``download`` and returns the
    reader over the archive members whose name contains 'test_batch'.
    """
    archive_path = download(md5=CIFAR10_MD5, url=CIFAR10_URL)
    return __read_batch__(filename=archive_path, sub_name='test_batch')
99+
100+
101+
def unittest():
    """
    Smoke test: fully iterate the default train reader, then the
    default test reader, discarding every sample.
    """
    for creator in (train_creator, test_creator):
        reader = creator()
        for _ in reader():
            pass


# Run the smoke test when the module is executed as a script.
if __name__ == '__main__':
    unittest()

0 commit comments

Comments
 (0)