@@ -121,7 +121,7 @@ def __init__(self, uir_tuple, max_rating, min_rating, global_mean,
121121 self .max_rating = max_rating
122122 self .min_rating = min_rating
123123 self .global_mean = global_mean
124- self .seed = seed
124+ self .rng = get_rng ( seed )
125125
126126 self .__csr_matrix = None
127127 self .__csc_matrix = None
@@ -283,7 +283,7 @@ def idx_iter(self, idx_range, batch_size=1, shuffle=False):
283283 """
284284 indices = np .arange (idx_range )
285285 if shuffle :
286- get_rng ( self .seed ) .shuffle (indices )
286+ self .rng .shuffle (indices )
287287
288288 n_batches = estimate_batches (len (indices ), batch_size )
289289 for b in range (n_batches ):
@@ -337,9 +337,9 @@ def uij_iter(self, batch_size=1, shuffle=False):
337337 batch_pos_ratings = self .uir_tuple [2 ][batch_ids ]
338338 batch_neg_items = np .zeros_like (batch_pos_items )
339339 for i , (user , pos_rating ) in enumerate (zip (batch_users , batch_pos_ratings )):
340- neg_item = get_rng ( self .seed ) .randint (0 , self .num_items )
340+ neg_item = self .rng .randint (0 , self .num_items )
341341 while self .dok_matrix [user , neg_item ] >= pos_rating :
342- neg_item = get_rng ( self .seed ) .randint (0 , self .num_items )
342+ neg_item = self .rng .randint (0 , self .num_items )
343343 batch_neg_items [i ] = neg_item
344344 yield batch_users , batch_pos_items , batch_neg_items
345345
@@ -406,7 +406,7 @@ class MultimodalTrainSet(MatrixTrainSet):
406406 """
407407
408408 def __init__ (self , matrix , max_rating , min_rating , global_mean , uid_map , iid_map , ** kwargs ):
409- super ().__init__ (matrix , max_rating , min_rating , global_mean , uid_map , iid_map )
409+ super ().__init__ (matrix , max_rating , min_rating , global_mean , uid_map , iid_map , ** kwargs )
410410 self .add_modalities (** kwargs )
411411
412412 def add_modalities (self , ** kwargs ):
0 commit comments