Skip to content

Commit dfe6a2c

Browse files
authored
Merge pull request #1466 from reyoung/feature/movielens_data
Add MovieLens DataSet
2 parents 283e82f + de9012a commit dfe6a2c

File tree

3 files changed

+153
-32
lines changed

3 files changed

+153
-32
lines changed

python/paddle/v2/dataset/cifar.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@
55
66
the default train_creator, test_creator used for CIFAR-10 dataset.
77
"""
8-
from config import DATA_HOME
9-
import os
10-
import hashlib
11-
import urllib2
12-
import shutil
13-
import tarfile
148
import cPickle
159
import itertools
10+
import tarfile
11+
1612
import numpy
1713

14+
from config import download
15+
1816
__all__ = [
1917
'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
2018
'test_creator'
@@ -47,31 +45,6 @@ def __read_one_batch_impl__(batch):
4745
return reader
4846

4947

50-
def download(url, md5):
51-
filename = os.path.split(url)[-1]
52-
assert DATA_HOME is not None
53-
filepath = os.path.join(DATA_HOME, md5)
54-
if not os.path.exists(filepath):
55-
os.makedirs(filepath)
56-
__full_file__ = os.path.join(filepath, filename)
57-
58-
def __file_ok__():
59-
if not os.path.exists(__full_file__):
60-
return False
61-
md5_hash = hashlib.md5()
62-
with open(__full_file__, 'rb') as f:
63-
for chunk in iter(lambda: f.read(4096), b""):
64-
md5_hash.update(chunk)
65-
66-
return md5_hash.hexdigest() == md5
67-
68-
while not __file_ok__():
69-
response = urllib2.urlopen(url)
70-
with open(__full_file__, mode='wb') as of:
71-
shutil.copyfileobj(fsrc=response, fdst=of)
72-
return __full_file__
73-
74-
7548
def cifar_100_train_creator():
    """Create a reader over the CIFAR-100 training batch.

    Downloads (or reuses the cached copy of) the CIFAR-100 archive and
    returns a reader for its 'train' batch.
    """
    return __read_batch__(
        download(url=CIFAR100_URL, md5=CIFAR100_MD5), 'train')

python/paddle/v2/dataset/config.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,36 @@
1+
import hashlib
12
import os
3+
import shutil
4+
import urllib2
25

3-
__all__ = ['DATA_HOME']
6+
__all__ = ['DATA_HOME', 'download']

# Root directory of the local dataset cache; downloads land in
# per-checksum subdirectories beneath it.
DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set')

# Create the cache root at import time so callers can assume it exists.
# NOTE(review): exists-check followed by makedirs is racy if several
# processes import this module at once — confirm single-process use.
if not os.path.exists(DATA_HOME):
    os.makedirs(DATA_HOME)
12+
13+
14+
def download(url, md5):
    """Download ``url`` into the local cache and return the file path.

    The file is stored as ``DATA_HOME/<md5>/<basename of url>``.  A cached
    file whose MD5 digest matches ``md5`` is reused without re-downloading;
    otherwise the file is fetched, with a bounded number of attempts.

    :param url: source URL of the data file.
    :param md5: expected hexadecimal MD5 digest of the file content.
    :return: path of the verified local file.
    :raises RuntimeError: if the checksum still mismatches after retrying.
    """
    filename = os.path.split(url)[-1]
    assert DATA_HOME is not None
    filepath = os.path.join(DATA_HOME, md5)
    if not os.path.exists(filepath):
        os.makedirs(filepath)
    full_path = os.path.join(filepath, filename)

    def file_ok():
        # True iff the cached file exists and its MD5 matches expectation.
        if not os.path.exists(full_path):
            return False
        md5_hash = hashlib.md5()
        with open(full_path, 'rb') as f:
            # Hash in 4 KiB chunks so large archives never load into memory.
            for chunk in iter(lambda: f.read(4096), b""):
                md5_hash.update(chunk)
        return md5_hash.hexdigest() == md5

    # Bounded retry: the previous unbounded ``while not ok`` loop spun
    # forever when the server kept returning corrupt data or the caller
    # passed a wrong md5.
    for _ in range(3):
        if file_ok():
            return full_path
        response = urllib2.urlopen(url)
        with open(full_path, mode='wb') as of:
            shutil.copyfileobj(fsrc=response, fdst=of)
    if not file_ok():
        raise RuntimeError(
            'MD5 mismatch for %s after 3 download attempts' % url)
    return full_path
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import zipfile
2+
from config import download
3+
import re
4+
import random
5+
import functools
6+
7+
__all__ = ['train_creator', 'test_creator']
8+
9+
10+
class MovieInfo(object):
    """One movie record from ``movies.dat``.

    Stores the numeric movie id, its category names, and the title with
    the trailing year already stripped by the caller.
    """

    def __init__(self, index, categories, title):
        self.index = int(index)
        self.categories = categories
        self.title = title

    def value(self):
        """Encode as ``[id, category-id list, title-word-id list]``.

        Uses the module-level ``CATEGORIES_DICT`` and ``MOVIE_TITLE_DICT``
        vocabularies; title words are lowercased before lookup.
        """
        category_ids = [CATEGORIES_DICT[name] for name in self.categories]
        word_ids = [MOVIE_TITLE_DICT[word.lower()]
                    for word in self.title.split()]
        return [self.index, category_ids, word_ids]
22+
23+
class UserInfo(object):
    """One user record from ``users.dat``.

    Gender is stored as a boolean flag and age as the index of the
    MovieLens-1M age bucket the raw value falls into.
    """

    def __init__(self, index, gender, age, job_id):
        # MovieLens-1M encodes ages as one of these bucket labels.
        age_buckets = [1, 18, 25, 35, 45, 50, 56]
        self.index = int(index)
        self.is_male = (gender == 'M')
        self.age = age_buckets.index(int(age))
        self.job_id = int(job_id)

    def value(self):
        """Encode as ``[id, gender(0=male/1=female), age bucket, job id]``."""
        gender_code = 0 if self.is_male else 1
        return [self.index, gender_code, self.age, self.job_id]
33+
34+
# Module-level caches, lazily populated on first use by
# __initialize_meta_info__(); all remain None until then.
MOVIE_INFO = None        # dict: movie id -> MovieInfo
MOVIE_TITLE_DICT = None  # dict: lowercased title word -> integer id
CATEGORIES_DICT = None   # dict: category name -> integer id
USER_INFO = None         # dict: user id -> UserInfo
38+
39+
40+
def __initialize_meta_info__():
    """Download MovieLens-1M and build the module-level metadata caches.

    Fills MOVIE_INFO, MOVIE_TITLE_DICT, CATEGORIES_DICT and USER_INFO on
    the first call; later calls only re-download/verify via the cache in
    ``download`` and return the archive path.

    :return: local path of the downloaded ``ml-1m.zip`` archive.
    """
    fn = download(
        url='http://files.grouplens.org/datasets/movielens/ml-1m.zip',
        md5='c4d9eecfca2ab87c1945afe126590906')
    global MOVIE_INFO
    if MOVIE_INFO is None:
        # Titles look like "Some Title (1995)"; group(1) strips the year.
        pattern = re.compile(r'^(.*)\((\d+)\)$')
        with zipfile.ZipFile(file=fn) as package:
            # NOTE(review): this loop only sanity-checks entry types; it
            # builds nothing and could likely be dropped — confirm intent.
            for info in package.infolist():
                assert isinstance(info, zipfile.ZipInfo)
            MOVIE_INFO = dict()
            title_word_set = set()
            categories_set = set()
            with package.open('ml-1m/movies.dat') as movie_file:
                # movies.dat lines: movie_id::title::cat1|cat2|...
                for i, line in enumerate(movie_file):
                    movie_id, title, categories = line.strip().split('::')
                    categories = categories.split('|')
                    for c in categories:
                        categories_set.add(c)
                    title = pattern.match(title).group(1)
                    MOVIE_INFO[int(movie_id)] = MovieInfo(
                        index=movie_id, categories=categories, title=title)
                    for w in title.split():
                        title_word_set.add(w.lower())

            # Assign integer ids to title words in set-iteration order.
            global MOVIE_TITLE_DICT
            MOVIE_TITLE_DICT = dict()
            for i, w in enumerate(title_word_set):
                MOVIE_TITLE_DICT[w] = i

            # Assign integer ids to categories in set-iteration order.
            global CATEGORIES_DICT
            CATEGORIES_DICT = dict()
            for i, c in enumerate(categories_set):
                CATEGORIES_DICT[c] = i

            global USER_INFO
            USER_INFO = dict()
            with package.open('ml-1m/users.dat') as user_file:
                # users.dat lines: uid::gender::age::job::zip (zip unused)
                for line in user_file:
                    uid, gender, age, job, _ = line.strip().split("::")
                    USER_INFO[int(uid)] = UserInfo(
                        index=uid, gender=gender, age=age, job_id=job)
    return fn
83+
84+
85+
def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
    """Yield MovieLens rating samples for the train or test split.

    Each sample is ``user.value() + movie.value() + [[rating]]`` where the
    raw 1..5 rating is rescaled to ``rating * 2 - 5`` (i.e. -3..5).

    :param rand_seed: seed for the split RNG; the same seed yields the
        same train/test partition across calls.
    :param test_ratio: approximate fraction of lines routed to the test
        split.
    :param is_test: yield test-split lines when True, train-split
        otherwise.
    """
    fn = __initialize_meta_info__()
    # Seeded RNG: one draw per ratings line decides its split, so the
    # partition is deterministic and train/test are exact complements.
    rand = random.Random(x=rand_seed)
    with zipfile.ZipFile(file=fn) as package:
        with package.open('ml-1m/ratings.dat') as rating:
            for line in rating:
                if (rand.random() < test_ratio) == is_test:
                    # NOTE(review): ``rating`` is rebound here from the
                    # file handle to the rating string; iteration still
                    # works because the for-loop already holds the
                    # iterator, but the shadowing is easy to misread.
                    uid, mov_id, rating, _ = line.strip().split("::")
                    uid = int(uid)
                    mov_id = int(mov_id)
                    rating = float(rating) * 2 - 5.0

                    mov = MOVIE_INFO[mov_id]
                    usr = USER_INFO[uid]
                    yield usr.value() + mov.value() + [[rating]]
101+
102+
def __reader_creator__(**kwargs):
103+
return lambda: __reader__(**kwargs)
104+
105+
106+
train_creator = functools.partial(__reader_creator__, is_test=False)
107+
test_creator = functools.partial(__reader_creator__, is_test=True)
108+
109+
110+
def unittest():
111+
for train_count, _ in enumerate(train_creator()()):
112+
pass
113+
for test_count, _ in enumerate(test_creator()()):
114+
pass
115+
116+
print train_count, test_count
117+
118+
119+
if __name__ == '__main__':
120+
unittest()

0 commit comments

Comments
 (0)