|
| 1 | +import zipfile |
| 2 | +from config import download |
| 3 | +import re |
| 4 | +import random |
| 5 | +import functools |
| 6 | + |
# Names exported by a star-import of this module: the two reader factories
# defined at the bottom of the file.
__all__ = ['train_creator', 'test_creator']
| 8 | + |
| 9 | + |
class MovieInfo(object):
    """Metadata of one movie: numeric id, category names and title text."""

    def __init__(self, index, categories, title):
        """Store the movie id (coerced to ``int``), its categories and title."""
        self.index = int(index)
        self.categories = categories
        self.title = title

    def value(self):
        """Return ``[index, category_ids, title_word_ids]``.

        Category names are looked up in the module-level ``CATEGORIES_DICT``
        and lower-cased, whitespace-separated title words in
        ``MOVIE_TITLE_DICT``; both must be populated before this is called.
        """
        category_ids = []
        for name in self.categories:
            category_ids.append(CATEGORIES_DICT[name])

        word_ids = []
        for word in self.title.split():
            word_ids.append(MOVIE_TITLE_DICT[word.lower()])

        return [self.index, category_ids, word_ids]
| 21 | + |
| 22 | + |
class UserInfo(object):
    """Metadata of one user: id, gender flag, age bucket and job id."""

    def __init__(self, index, gender, age, job_id):
        """Coerce the raw fields.

        ``age`` must be one of the listed bucket codes; it is stored as the
        position of that code in the list. Any other value raises
        ``ValueError``.
        """
        age_codes = [1, 18, 25, 35, 45, 50, 56]
        self.index = int(index)
        self.is_male = (gender == 'M')
        self.age = age_codes.index(int(age))
        self.job_id = int(job_id)

    def value(self):
        """Return ``[index, gender_flag, age_bucket, job_id]`` (male -> 0, otherwise 1)."""
        if self.is_male:
            gender_flag = 0
        else:
            gender_flag = 1
        return [self.index, gender_flag, self.age, self.job_id]
| 32 | + |
| 33 | + |
# Lazily populated caches, filled by __initialize_meta_info__() on its first
# successful run. All four stay None until then.
MOVIE_INFO = None  # movie id (int) -> MovieInfo
MOVIE_TITLE_DICT = None  # lower-cased title word -> integer id
CATEGORIES_DICT = None  # category name -> integer id
USER_INFO = None  # user id (int) -> UserInfo


def _as_text(line):
    """Return *line* as ``str``, decoding it as latin-1 when it is bytes."""
    # Zip members are read as bytes. On Python 2 bytes *is* str, so lines
    # pass through untouched; on Python 3 they must be decoded before being
    # split on a text separator. latin-1 maps every byte, so it cannot fail.
    if isinstance(line, str):
        return line
    return line.decode('latin-1')


def __initialize_meta_info__():
    """Fetch the ml-1m archive and build the module-level caches once.

    ``download`` is called on every invocation and its result is returned;
    callers pass that value to ``zipfile.ZipFile``. Parsing of
    ``ml-1m/movies.dat`` and ``ml-1m/users.dat`` happens only while
    ``MOVIE_INFO`` is still ``None``.

    Returns:
        The value returned by ``download`` for the archive.

    Raises:
        AttributeError: if a movie title does not end with ``(<digits>)``.
        ValueError: if a line has an unexpected number of ``::`` fields, or
            a user's age is not a known bucket code.
    """
    global MOVIE_INFO, MOVIE_TITLE_DICT, CATEGORIES_DICT, USER_INFO

    fn = download(
        url='http://files.grouplens.org/datasets/movielens/ml-1m.zip',
        md5='c4d9eecfca2ab87c1945afe126590906')

    if MOVIE_INFO is None:
        # Captures the title text that precedes the trailing "(<digits>)".
        pattern = re.compile(r'^(.*)\((\d+)\)$')

        # Everything is built in locals and published at the very end, so an
        # error half-way through never leaves MOVIE_INFO set while the other
        # caches are still None.
        movies = dict()
        users = dict()
        title_word_set = set()
        categories_set = set()

        with zipfile.ZipFile(file=fn) as package:
            with package.open('ml-1m/movies.dat') as movie_file:
                for line in movie_file:
                    movie_id, title, categories = \
                        _as_text(line).strip().split('::')
                    categories = categories.split('|')
                    categories_set.update(categories)
                    title = pattern.match(title).group(1)
                    movies[int(movie_id)] = MovieInfo(
                        index=movie_id, categories=categories, title=title)
                    title_word_set.update(
                        w.lower() for w in title.split())

            with package.open('ml-1m/users.dat') as user_file:
                for line in user_file:
                    uid, gender, age, job, _ = \
                        _as_text(line).strip().split("::")
                    users[int(uid)] = UserInfo(
                        index=uid, gender=gender, age=age, job_id=job)

        # Ids are assigned in sorted order so they are identical in every
        # process; raw set iteration order depends on string hashing.
        MOVIE_TITLE_DICT = dict(
            (w, i) for i, w in enumerate(sorted(title_word_set)))
        CATEGORIES_DICT = dict(
            (c, i) for i, c in enumerate(sorted(categories_set)))
        USER_INFO = users
        # Published last: this is the "already initialized" sentinel.
        MOVIE_INFO = movies
    return fn
| 83 | + |
| 84 | + |
def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
    """Yield one sample per rating that falls into the requested split.

    One pseudo-random number is drawn for every line of
    ``ml-1m/ratings.dat``. The line belongs to the test split when that
    number is below ``test_ratio``, and to the train split otherwise. With
    the same ``rand_seed`` and ``test_ratio`` the two splits are therefore
    disjoint and together cover the file.

    Args:
        rand_seed: seed for the private ``random.Random`` instance.
        test_ratio: probability threshold for the test split.
        is_test: yield the test split when true, the train split otherwise.

    Yields:
        ``usr.value() + mov.value() + [[score]]`` where ``score`` is the
        raw rating mapped through ``rating * 2 - 5.0``.
    """
    fn = __initialize_meta_info__()
    # The seed is passed positionally so the call does not depend on the
    # name of the constructor's parameter.
    rand = random.Random(rand_seed)
    with zipfile.ZipFile(file=fn) as package:
        with package.open('ml-1m/ratings.dat') as rating_file:
            for line in rating_file:
                # Draw for every line, selected or not, so both splits see
                # the same random sequence.
                if (rand.random() < test_ratio) != is_test:
                    continue

                # Zip members yield bytes on Python 3; on Python 2 the line
                # is already str and is left untouched.
                if not isinstance(line, str):
                    line = line.decode('latin-1')

                # `score` is a separate name from the file handle so the
                # object being iterated is never rebound.
                uid, mov_id, score, _ = line.strip().split("::")
                uid = int(uid)
                mov_id = int(mov_id)
                score = float(score) * 2 - 5.0

                mov = MOVIE_INFO[mov_id]
                usr = USER_INFO[uid]
                yield usr.value() + mov.value() + [[score]]
| 100 | + |
| 101 | + |
def __reader_creator__(**kwargs):
    """Return a zero-argument callable that starts a new ``__reader__`` run.

    The keyword arguments are captured now and forwarded unchanged to
    ``__reader__`` each time the returned callable is invoked.
    """

    def reader():
        return __reader__(**kwargs)

    return reader


# Public factories: same as __reader_creator__ with the split pre-selected.
train_creator = functools.partial(__reader_creator__, is_test=False)
test_creator = functools.partial(__reader_creator__, is_test=True)
| 108 | + |
| 109 | + |
def unittest():
    """Iterate both splits once and print how many samples each produced.

    Prints ``"<train_count> <test_count>"`` on one line. The counts are the
    real number of yielded samples, and are 0 for an empty split.
    """
    # sum() over the generator counts items directly. The last enumerate()
    # index is one less than the count and is undefined for an empty split.
    train_count = sum(1 for _ in train_creator()())
    test_count = sum(1 for _ in test_creator()())

    # A single formatted string prints the same on Python 2 and Python 3.
    print('%d %d' % (train_count, test_count))


if __name__ == '__main__':
    unittest()
0 commit comments