
Commit bfc3310

Merge pull request #1687 from Yancey1989/dataset_cache_api
Add download api for dataset
2 parents: 0b59be2 + 14eb5b8

File tree

10 files changed: +68, -19 lines
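
Taken together, the change gives each dataset module a module-level fetch() that only downloads and verifies its files into the shared cache (DATA_HOME), and adds common.fetch_all(), which walks paddle.v2.dataset and calls fetch() on every module that defines one. A minimal usage sketch of the new API; the surrounding script is illustrative, not part of the commit:

# Pre-download dataset files without constructing any readers.
import paddle.v2.dataset.common as common
import paddle.v2.dataset.mnist as mnist

mnist.fetch()       # cache just the MNIST archives under DATA_HOME
common.fetch_all()  # or cache every dataset module that defines fetch()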

python/paddle/v2/dataset/cifar.py

Lines changed: 10 additions & 10 deletions
@@ -20,7 +20,7 @@
 import cPickle
 import itertools
 import numpy
-import paddle.v2.dataset.common
+from common import download
 import tarfile
 
 __all__ = ['train100', 'test100', 'train10', 'test10']
@@ -55,23 +55,23 @@ def reader():
 
 def train100():
     return reader_creator(
-        paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
-        'train')
+        download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'train')
 
 
 def test100():
-    return reader_creator(
-        paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
-        'test')
+    return reader_creator(download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'test')
 
 
 def train10():
     return reader_creator(
-        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
-        'data_batch')
+        download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'data_batch')
 
 
 def test10():
     return reader_creator(
-        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
-        'test_batch')
+        download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'test_batch')
+
+
+def fetch():
+    download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
+    download(CIFAR100_URL, 'cifar', CIFAR100_MD5)
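
With the cifar module patched as above, fetch() warms the same cache that the readers pull from; a short sketch assuming the usual paddle.v2 reader-creator convention (unpacking a sample into image and label is an assumption for illustration):

import paddle.v2.dataset.cifar as cifar

cifar.fetch()             # download the CIFAR-10 and CIFAR-100 archives once, up front
reader = cifar.train10()  # reader creator: calling it yields training samples
image, label = next(reader())  # first sample, served from the local cache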

python/paddle/v2/dataset/common.py

Lines changed: 12 additions & 0 deletions
@@ -17,6 +17,8 @@
 import os
 import shutil
 import sys
+import importlib
+import paddle.v2.dataset
 
 __all__ = ['DATA_HOME', 'download', 'md5file']
 
@@ -69,3 +71,13 @@ def dict_add(a_dict, ele):
         a_dict[ele] += 1
     else:
         a_dict[ele] = 1
+
+
+def fetch_all():
+    for module_name in filter(lambda x: not x.startswith("__"),
+                              dir(paddle.v2.dataset)):
+        if "fetch" in dir(
+                importlib.import_module("paddle.v2.dataset.%s" % module_name)):
+            getattr(
+                importlib.import_module("paddle.v2.dataset.%s" % module_name),
+                "fetch")()

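fetch_all() discovers the dataset modules by introspecting the paddle.v2.dataset package and invokes fetch() on each one that defines it, so a single call can prime the cache for every dataset. A hedged example of calling it from a standalone script (the script itself is not part of the commit):

# Prefetch everything that exposes fetch(); files end up under DATA_HOME.
import paddle.v2.dataset.common as common

if __name__ == '__main__':
    common.fetch_all()
    print 'datasets cached under %s' % common.DATA_HOME
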
python/paddle/v2/dataset/conll05.py

Lines changed: 8 additions & 0 deletions
@@ -196,3 +196,11 @@ def test():
         words_name='conll05st-release/test.wsj/words/test.wsj.words.gz',
         props_name='conll05st-release/test.wsj/props/test.wsj.props.gz')
     return reader_creator(reader, word_dict, verb_dict, label_dict)
+
+
+def fetch():
+    download(WORDDICT_URL, 'conll05st', WORDDICT_MD5)
+    download(VERBDICT_URL, 'conll05st', VERBDICT_MD5)
+    download(TRGDICT_URL, 'conll05st', TRGDICT_MD5)
+    download(EMB_URL, 'conll05st', EMB_MD5)
+    download(DATA_URL, 'conll05st', DATA_MD5)

python/paddle/v2/dataset/imdb.py

Lines changed: 4 additions & 0 deletions
@@ -123,3 +123,7 @@ def test(word_idx):
 def word_dict():
     return build_dict(
         re.compile("aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150)
+
+
+def fetch():
+    paddle.v2.dataset.common.download(URL, 'imdb', MD5)

python/paddle/v2/dataset/imikolov.py

Lines changed: 4 additions & 0 deletions
@@ -89,3 +89,7 @@ def train(word_idx, n):
 
 def test(word_idx, n):
     return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n)
+
+
+def fetch():
+    paddle.v2.dataset.common.download(URL, "imikolov", MD5)

python/paddle/v2/dataset/mnist.py

Lines changed: 7 additions & 0 deletions
@@ -106,3 +106,10 @@ def test():
                                           TEST_IMAGE_MD5),
         paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist',
                                           TEST_LABEL_MD5), 100)
+
+
+def fetch():
+    paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
+    paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
+    paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
+    paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)

python/paddle/v2/dataset/movielens.py

Lines changed: 8 additions & 4 deletions
@@ -30,6 +30,9 @@
 
 age_table = [1, 18, 25, 35, 45, 50, 56]
 
+URL = 'http://files.grouplens.org/datasets/movielens/ml-1m.zip'
+MD5 = 'c4d9eecfca2ab87c1945afe126590906'
+
 
 class MovieInfo(object):
     def __init__(self, index, categories, title):
@@ -77,10 +80,7 @@ def __repr__(self):
 
 
 def __initialize_meta_info__():
-    fn = download(
-        url='http://files.grouplens.org/datasets/movielens/ml-1m.zip',
-        module_name='movielens',
-        md5sum='c4d9eecfca2ab87c1945afe126590906')
+    fn = download(URL, "movielens", MD5)
     global MOVIE_INFO
     if MOVIE_INFO is None:
         pattern = re.compile(r'^(.*)\((\d+)\)$')
@@ -205,5 +205,9 @@ def unittest():
     print train_count, test_count
 
 
+def fetch():
+    download(URL, "movielens", MD5)
+
+
 if __name__ == '__main__':
     unittest()

python/paddle/v2/dataset/sentiment.py

Lines changed: 4 additions & 0 deletions
@@ -125,3 +125,7 @@ def test():
     """
     data_set = load_sentiment_data()
     return reader_creator(data_set[NUM_TRAINING_INSTANCES:])
+
+
+def fetch():
+    nltk.download('movie_reviews', download_dir=common.DATA_HOME)
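
Unlike the other modules, sentiment's fetch() goes through NLTK's corpus downloader rather than common.download(), so the movie_reviews corpus ends up under the same DATA_HOME cache. A small sketch; the nltk.data.path line is an assumption about how a caller would point NLTK at that directory, not something this commit adds:

import nltk
import paddle.v2.dataset.common as common

# Mirror sentiment.fetch(), then tell NLTK where to find the corpus.
nltk.download('movie_reviews', download_dir=common.DATA_HOME)
nltk.data.path.append(common.DATA_HOME)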

python/paddle/v2/dataset/uci_housing.py

Lines changed: 4 additions & 0 deletions
@@ -89,3 +89,7 @@ def reader():
             yield d[:-1], d[-1:]
 
     return reader
+
+
+def fetch():
+    download(URL, 'uci_housing', MD5)

python/paddle/v2/dataset/wmt14.py

Lines changed: 7 additions & 5 deletions
@@ -16,7 +16,7 @@
 """
 import tarfile
 
-import paddle.v2.dataset.common
+from paddle.v2.dataset.common import download
 
 __all__ = ['train', 'test', 'build_dict']
 
@@ -95,11 +95,13 @@ def reader():
 
 def train(dict_size):
     return reader_creator(
-        paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
-        'train/train', dict_size)
+        download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'train/train', dict_size)
 
 
 def test(dict_size):
     return reader_creator(
-        paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
-        'test/test', dict_size)
+        download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)
+
+
+def fetch():
+    download(URL_TRAIN, 'wmt14', MD5_TRAIN)
