"""
CIFAR Dataset.

URL: https://www.cs.toronto.edu/~kriz/cifar.html

The default train_creator and test_creator use the CIFAR-10 dataset.
"""
from config import DATA_HOME
import os
import hashlib
import urllib2
import shutil
import tarfile
import cPickle
import itertools
import numpy

__all__ = [
    'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
    'test_creator'
]

CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'

def __read_batch__(filename, sub_name):
    def reader():
        def __read_one_batch_impl__(batch):
            data = batch['data']
            # CIFAR-10 batches keep class indices under 'labels';
            # CIFAR-100 batches keep them under 'fine_labels'.
            labels = batch.get('labels', batch.get('fine_labels', None))
            assert labels is not None
            for sample, label in itertools.izip(data, labels):
                # Scale pixel values from [0, 255] to [0, 1].
                yield (sample / 255.0).astype(numpy.float32), int(label)

        with tarfile.open(filename, mode='r') as f:
            # Select every archive member whose name contains sub_name,
            # e.g. 'data_batch' or 'test_batch' for CIFAR-10.
            names = (each_item.name for each_item in f
                     if sub_name in each_item.name)

            for name in names:
                batch = cPickle.load(f.extractfile(name))
                for item in __read_one_batch_impl__(batch):
                    yield item

    return reader


def download(url, md5):
    # Cache the archive under DATA_HOME/<md5>/<filename> and keep
    # re-downloading it until the local MD5 checksum matches.
    filename = os.path.split(url)[-1]
    assert DATA_HOME is not None
    filepath = os.path.join(DATA_HOME, md5)
    if not os.path.exists(filepath):
        os.makedirs(filepath)
    __full_file__ = os.path.join(filepath, filename)

    def __file_ok__():
        if not os.path.exists(__full_file__):
            return False
        # Compute the MD5 of the cached file in 4 KB chunks.
        md5_hash = hashlib.md5()
        with open(__full_file__, 'rb') as f:
            for chunk in iter(lambda: f.read(4096), b""):
                md5_hash.update(chunk)

        return md5_hash.hexdigest() == md5

    while not __file_ok__():
        response = urllib2.urlopen(url)
        with open(__full_file__, mode='wb') as of:
            shutil.copyfileobj(fsrc=response, fdst=of)
    return __full_file__
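
# For example (an assumption, only illustrating the cache layout built above):
# with DATA_HOME = '/home/user/.cache/paddle_data', the CIFAR-10 archive would
# end up at
# /home/user/.cache/paddle_data/c58f30108f718f92721af3b95e74349a/cifar-10-python.tar.gz
# and is only re-downloaded when the MD5 check fails.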


def cifar_100_train_creator():
    fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
    return __read_batch__(fn, 'train')


def cifar_100_test_creator():
    fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
    return __read_batch__(fn, 'test')


def train_creator():
    """
    Default train reader creator. Uses the CIFAR-10 dataset.
    """
    fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
    return __read_batch__(fn, 'data_batch')


def test_creator():
    """
    Default test reader creator. Uses the CIFAR-10 dataset.
    """
    fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
    return __read_batch__(fn, 'test_batch')


def unittest():
    # Smoke test: iterate once over both default readers.
    for _ in train_creator()():
        pass
    for _ in test_creator()():
        pass


if __name__ == '__main__':
    unittest()
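
# Usage sketch (assuming DATA_HOME is configured in config.py):
#
#     reader = train_creator()
#     for image, label in itertools.islice(reader(), 5):
#         # image: flattened float32 array of length 3072 (3 x 32 x 32),
#         #        scaled to [0, 1]; label: int in [0, 9] for CIFAR-10
#         print image.shape, label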