Skip to content

Commit 3d0a160

Browse files
authored
Update MatrixTrainSet constructor to pass kwargs to super (#203)
1 parent 0c7a88a commit 3d0a160

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

cornac/data/trainset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __init__(self, uir_tuple, max_rating, min_rating, global_mean,
121121
self.max_rating = max_rating
122122
self.min_rating = min_rating
123123
self.global_mean = global_mean
124-
self.seed = seed
124+
self.rng = get_rng(seed)
125125

126126
self.__csr_matrix = None
127127
self.__csc_matrix = None
@@ -283,7 +283,7 @@ def idx_iter(self, idx_range, batch_size=1, shuffle=False):
283283
"""
284284
indices = np.arange(idx_range)
285285
if shuffle:
286-
get_rng(self.seed).shuffle(indices)
286+
self.rng.shuffle(indices)
287287

288288
n_batches = estimate_batches(len(indices), batch_size)
289289
for b in range(n_batches):
@@ -337,9 +337,9 @@ def uij_iter(self, batch_size=1, shuffle=False):
337337
batch_pos_ratings = self.uir_tuple[2][batch_ids]
338338
batch_neg_items = np.zeros_like(batch_pos_items)
339339
for i, (user, pos_rating) in enumerate(zip(batch_users, batch_pos_ratings)):
340-
neg_item = get_rng(self.seed).randint(0, self.num_items)
340+
neg_item = self.rng.randint(0, self.num_items)
341341
while self.dok_matrix[user, neg_item] >= pos_rating:
342-
neg_item = get_rng(self.seed).randint(0, self.num_items)
342+
neg_item = self.rng.randint(0, self.num_items)
343343
batch_neg_items[i] = neg_item
344344
yield batch_users, batch_pos_items, batch_neg_items
345345

@@ -406,7 +406,7 @@ class MultimodalTrainSet(MatrixTrainSet):
406406
"""
407407

408408
def __init__(self, matrix, max_rating, min_rating, global_mean, uid_map, iid_map, **kwargs):
409-
super().__init__(matrix, max_rating, min_rating, global_mean, uid_map, iid_map)
409+
super().__init__(matrix, max_rating, min_rating, global_mean, uid_map, iid_map, **kwargs)
410410
self.add_modalities(**kwargs)
411411

412412
def add_modalities(self, **kwargs):

0 commit comments

Comments
 (0)