23
23
import random
24
24
import functools
25
25
26
- __all__ = ['train_creator' , 'test_creator' ]
26
+ __all__ = [
27
+ 'train' , 'test' , 'get_movie_title_dict' , 'max_movie_id' , 'max_user_id' ,
28
+ 'age_table' , 'movie_categories' , 'max_job_id' , 'user_info' , 'movie_info'
29
+ ]
30
+
31
+ age_table = [1 , 18 , 25 , 35 , 45 , 50 , 56 ]
27
32
28
33
29
34
class MovieInfo (object ):
@@ -38,17 +43,32 @@ def value(self):
38
43
[MOVIE_TITLE_DICT [w .lower ()] for w in self .title .split ()]
39
44
]
40
45
46
+ def __str__ (self ):
47
+ return "<MovieInfo id(%d), title(%s), categories(%s)>" % (
48
+ self .index , self .title , self .categories )
49
+
50
+ def __repr__ (self ):
51
+ return self .__str__ ()
52
+
41
53
42
54
class UserInfo (object ):
43
55
def __init__ (self , index , gender , age , job_id ):
44
56
self .index = int (index )
45
57
self .is_male = gender == 'M'
46
- self .age = [ 1 , 18 , 25 , 35 , 45 , 50 , 56 ] .index (int (age ))
58
+ self .age = age_table .index (int (age ))
47
59
self .job_id = int (job_id )
48
60
49
61
def value (self ):
50
62
return [self .index , 0 if self .is_male else 1 , self .age , self .job_id ]
51
63
64
+ def __str__ (self ):
65
+ return "<UserInfo id(%d), gender(%s), age(%d), job(%d)>" % (
66
+ self .index , "M"
67
+ if self .is_male else "F" , age_table [self .age ], self .job_id )
68
+
69
+ def __repr__ (self ):
70
+ return str (self )
71
+
52
72
53
73
MOVIE_INFO = None
54
74
MOVIE_TITLE_DICT = None
@@ -59,7 +79,8 @@ def value(self):
59
79
def __initialize_meta_info__ ():
60
80
fn = download (
61
81
url = 'http://files.grouplens.org/datasets/movielens/ml-1m.zip' ,
62
- md5 = 'c4d9eecfca2ab87c1945afe126590906' )
82
+ module_name = 'movielens' ,
83
+ md5sum = 'c4d9eecfca2ab87c1945afe126590906' )
63
84
global MOVIE_INFO
64
85
if MOVIE_INFO is None :
65
86
pattern = re .compile (r'^(.*)\((\d+)\)$' )
@@ -122,14 +143,63 @@ def __reader_creator__(**kwargs):
122
143
return lambda : __reader__ (** kwargs )
123
144
124
145
125
- train_creator = functools .partial (__reader_creator__ , is_test = False )
126
- test_creator = functools .partial (__reader_creator__ , is_test = True )
146
+ train = functools .partial (__reader_creator__ , is_test = False )
147
+ test = functools .partial (__reader_creator__ , is_test = True )
148
+
149
+
150
+ def get_movie_title_dict ():
151
+ __initialize_meta_info__ ()
152
+ return MOVIE_TITLE_DICT
153
+
154
+
155
+ def __max_index_info__ (a , b ):
156
+ if a .index > b .index :
157
+ return a
158
+ else :
159
+ return b
160
+
161
+
162
+ def max_movie_id ():
163
+ __initialize_meta_info__ ()
164
+ return reduce (__max_index_info__ , MOVIE_INFO .viewvalues ()).index
165
+
166
+
167
+ def max_user_id ():
168
+ __initialize_meta_info__ ()
169
+ return reduce (__max_index_info__ , USER_INFO .viewvalues ()).index
170
+
171
+
172
+ def __max_job_id_impl__ (a , b ):
173
+ if a .job_id > b .job_id :
174
+ return a
175
+ else :
176
+ return b
177
+
178
+
179
+ def max_job_id ():
180
+ __initialize_meta_info__ ()
181
+ return reduce (__max_job_id_impl__ , USER_INFO .viewvalues ()).job_id
182
+
183
+
184
+ def movie_categories ():
185
+ __initialize_meta_info__ ()
186
+ return CATEGORIES_DICT
187
+
188
+
189
+ def user_info ():
190
+ __initialize_meta_info__ ()
191
+ return USER_INFO
192
+
193
+
194
+ def movie_info ():
195
+ __initialize_meta_info__ ()
196
+ return MOVIE_INFO
127
197
128
198
129
199
def unittest ():
130
- for train_count , _ in enumerate (train_creator ()()):
200
+ for train_count , _ in enumerate (train ()()):
131
201
pass
132
- for test_count , _ in enumerate (test_creator ()()):
202
+ for test_count , _ in enumerate (test ()()):
133
203
pass
134
204
135
205
print train_count , test_count
0 commit comments