|
1 | 1 | """
|
2 |
| -CIFAR Dataset. |
3 |
| -
|
4 |
| -URL: https://www.cs.toronto.edu/~kriz/cifar.html |
5 |
| -
|
6 |
| -the default train_creator, test_creator used for CIFAR-10 dataset. |
| 2 | +CIFAR dataset: https://www.cs.toronto.edu/~kriz/cifar.html |
7 | 3 | """
|
8 | 4 | import cPickle
|
9 | 5 | import itertools
|
10 |
| -import tarfile |
11 |
| - |
12 | 6 | import numpy
|
| 7 | +import paddle.v2.dataset.common |
| 8 | +import tarfile |
13 | 9 |
|
14 |
| -from common import download |
15 |
| - |
16 |
| -__all__ = [ |
17 |
| - 'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator', |
18 |
| - 'test_creator' |
19 |
| -] |
| 10 | +__all__ = ['train100', 'test100', 'train10', 'test10'] |
20 | 11 |
|
21 |
| -CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' |
| 12 | +URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/' |
| 13 | +CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz' |
22 | 14 | CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
|
23 |
| -CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' |
| 15 | +CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz' |
24 | 16 | CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
|
25 | 17 |
|
26 | 18 |
|
27 |
| -def __read_batch__(filename, sub_name): |
28 |
| - def reader(): |
29 |
| - def __read_one_batch_impl__(batch): |
30 |
| - data = batch['data'] |
31 |
| - labels = batch.get('labels', batch.get('fine_labels', None)) |
32 |
| - assert labels is not None |
33 |
| - for sample, label in itertools.izip(data, labels): |
34 |
| - yield (sample / 255.0).astype(numpy.float32), int(label) |
| 19 | +def reader_creator(filename, sub_name): |
| 20 | + def read_batch(batch): |
| 21 | + data = batch['data'] |
| 22 | + labels = batch.get('labels', batch.get('fine_labels', None)) |
| 23 | + assert labels is not None |
| 24 | + for sample, label in itertools.izip(data, labels): |
| 25 | + yield (sample / 255.0).astype(numpy.float32), int(label) |
35 | 26 |
|
| 27 | + def reader(): |
36 | 28 | with tarfile.open(filename, mode='r') as f:
|
37 | 29 | names = (each_item.name for each_item in f
|
38 | 30 | if sub_name in each_item.name)
|
39 | 31 |
|
40 | 32 | for name in names:
|
41 | 33 | batch = cPickle.load(f.extractfile(name))
|
42 |
| - for item in __read_one_batch_impl__(batch): |
| 34 | + for item in read_batch(batch): |
43 | 35 | yield item
|
44 | 36 |
|
45 | 37 | return reader
|
46 | 38 |
|
47 | 39 |
|
48 |
| -def cifar_100_train_creator(): |
49 |
| - fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) |
50 |
| - return __read_batch__(fn, 'train') |
51 |
| - |
52 |
| - |
53 |
| -def cifar_100_test_creator(): |
54 |
| - fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) |
55 |
| - return __read_batch__(fn, 'test') |
56 |
| - |
57 |
| - |
58 |
| -def train_creator(): |
59 |
| - """ |
60 |
| - Default train reader creator. Use CIFAR-10 dataset. |
61 |
| - """ |
62 |
| - fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5) |
63 |
| - return __read_batch__(fn, 'data_batch') |
| 40 | +def train100(): |
| 41 | + return reader_creator( |
| 42 | + paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5), |
| 43 | + 'train') |
64 | 44 |
|
65 | 45 |
|
66 |
| -def test_creator(): |
67 |
| - """ |
68 |
| - Default test reader creator. Use CIFAR-10 dataset. |
69 |
| - """ |
70 |
| - fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5) |
71 |
| - return __read_batch__(fn, 'test_batch') |
| 46 | +def test100(): |
| 47 | + return reader_creator( |
| 48 | + paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5), |
| 49 | + 'test') |
72 | 50 |
|
73 | 51 |
|
74 |
| -def unittest(): |
75 |
| - for _ in train_creator()(): |
76 |
| - pass |
77 |
| - for _ in test_creator()(): |
78 |
| - pass |
| 52 | +def train10(): |
| 53 | + return reader_creator( |
| 54 | + paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), |
| 55 | + 'data_batch') |
79 | 56 |
|
80 | 57 |
|
81 |
| -if __name__ == '__main__': |
82 |
| - unittest() |
| 58 | +def test10(): |
| 59 | + return reader_creator( |
| 60 | + paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), |
| 61 | + 'test_batch') |
0 commit comments