Skip to content

Commit 0c7a88a

Browse files
authored
Add random seed to MatrixTrainSet (#202)
1 parent 8184c02 commit 0c7a88a

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

cornac/data/trainset.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class MatrixTrainSet(TrainSet):
110110
The dictionary containing mapping from original ids to mapped ids of items.
111111
112112
seed: int, optional, default: None
113-
Random seed for reproduce data sampling.
113+
Random seed for reproducing data sampling.
114114
115115
"""
116116

@@ -179,7 +179,7 @@ def item_ppl_rank(self):
179179

180180
@classmethod
181181
def from_uir(cls, data, global_uid_map=None, global_iid_map=None,
182-
global_ui_set=None, verbose=False):
182+
global_ui_set=None, seed=None, verbose=False):
183183
"""Constructing TrainSet from triplet data.
184184
185185
Parameters
@@ -196,6 +196,9 @@ def from_uir(cls, data, global_uid_map=None, global_iid_map=None,
196196
global_ui_set: :obj:`set`, optional, default: None
197197
The global set of tuples (user, item). This helps avoiding duplicate observations.
198198
199+
seed: int, optional, default: None
200+
Random seed for reproducing data sampling.
201+
199202
verbose: bool, default: False
200203
The verbosity flag.
201204
@@ -258,7 +261,7 @@ def from_uir(cls, data, global_uid_map=None, global_iid_map=None,
258261
print('Min rating = {:.1f}'.format(min_rating))
259262
print('Global mean = {:.1f}'.format(global_mean))
260263

261-
return cls(uir_tuple, max_rating, min_rating, global_mean, uid_map, iid_map)
264+
return cls(uir_tuple, max_rating, min_rating, global_mean, uid_map, iid_map, seed=seed)
262265

263266
def num_batches(self, batch_size):
264267
return estimate_batches(len(self.uir_tuple[0]), batch_size)

cornac/eval_methods/base_method.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _build_uir(self, train_data, test_data, val_data=None):
192192
if self.verbose:
193193
print('Building training set')
194194
self.train_set = MultimodalTrainSet.from_uir(
195-
train_data, self.global_uid_map, self.global_iid_map, global_ui_set, self.verbose)
195+
train_data, self.global_uid_map, self.global_iid_map, global_ui_set, self.seed, self.verbose)
196196

197197
if self.verbose:
198198
print('Building test set')

0 commit comments

Comments
 (0)