@@ -34,10 +34,8 @@ class GeneralDataLoader(AbstractDataLoader):
3434 """
3535 dl_type = DataLoaderType .ORIGIN
3636
37- def __init__ (self , config , dataset ,
38- batch_size = 1 , dl_format = InputType .POINTWISE , shuffle = False ):
39- super ().__init__ (config , dataset ,
40- batch_size = batch_size , dl_format = dl_format , shuffle = shuffle )
37+ def __init__ (self , config , dataset , batch_size = 1 , dl_format = InputType .POINTWISE , shuffle = False ):
38+ super ().__init__ (config , dataset , batch_size = batch_size , dl_format = dl_format , shuffle = shuffle )
4139
4240 @property
4341 def pr_end (self ):
@@ -47,7 +45,7 @@ def _shuffle(self):
4745 self .dataset .shuffle ()
4846
4947 def _next_batch_data (self ):
50- cur_data = self .dataset [self .pr : self .pr + self .step ]
48+ cur_data = self .dataset [self .pr :self .pr + self .step ]
5149 self .pr += self .step
5250 return cur_data
5351
@@ -70,14 +68,16 @@ class GeneralNegSampleDataLoader(NegSampleByMixin, AbstractDataLoader):
7068 shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``.
7169 """
7270
73- def __init__ (self , config , dataset , sampler , neg_sample_args ,
74- batch_size = 1 , dl_format = InputType .POINTWISE , shuffle = False ):
71+ def __init__ (
72+ self , config , dataset , sampler , neg_sample_args , batch_size = 1 , dl_format = InputType .POINTWISE , shuffle = False
73+ ):
7574 self .uid_field = dataset .uid_field
7675 self .iid_field = dataset .iid_field
7776 self .uid_list , self .uid2index , self .uid2items_num = None , None , None
7877
79- super ().__init__ (config , dataset , sampler , neg_sample_args ,
80- batch_size = batch_size , dl_format = dl_format , shuffle = shuffle )
78+ super ().__init__ (
79+ config , dataset , sampler , neg_sample_args , batch_size = batch_size , dl_format = dl_format , shuffle = shuffle
80+ )
8181
8282 def setup (self ):
8383 if self .user_inter_in_one_batch :
@@ -132,7 +132,7 @@ def _shuffle(self):
132132
133133 def _next_batch_data (self ):
134134 if self .user_inter_in_one_batch :
135- uid_list = self .uid_list [self .pr : self .pr + self .step ]
135+ uid_list = self .uid_list [self .pr :self .pr + self .step ]
136136 data_list = []
137137 for uid in uid_list :
138138 index = self .uid2index [uid ]
@@ -144,7 +144,7 @@ def _next_batch_data(self):
144144 self .pr += self .step
145145 return cur_data
146146 else :
147- cur_data = self ._neg_sampling (self .dataset [self .pr : self .pr + self .step ])
147+ cur_data = self ._neg_sampling (self .dataset [self .pr :self .pr + self .step ])
148148 self .pr += self .step
149149 return cur_data
150150
@@ -167,7 +167,7 @@ def _neg_sample_by_point_wise_sampling(self, inter_feat, neg_iids):
167167 new_data [self .iid_field ][pos_inter_num :] = neg_iids
168168 new_data = self .dataset .join (new_data )
169169 labels = torch .zeros (pos_inter_num * self .times )
170- labels [: pos_inter_num ] = 1.0
170+ labels [:pos_inter_num ] = 1.0
171171 new_data .update (Interaction ({self .label_field : labels }))
172172 return new_data
173173
@@ -203,8 +203,9 @@ class GeneralFullDataLoader(NegSampleMixin, AbstractDataLoader):
203203 """
204204 dl_type = DataLoaderType .FULL
205205
206- def __init__ (self , config , dataset , sampler , neg_sample_args ,
207- batch_size = 1 , dl_format = InputType .POINTWISE , shuffle = False ):
206+ def __init__ (
207+ self , config , dataset , sampler , neg_sample_args , batch_size = 1 , dl_format = InputType .POINTWISE , shuffle = False
208+ ):
208209 if neg_sample_args ['strategy' ] != 'full' :
209210 raise ValueError ('neg_sample strategy in GeneralFullDataLoader() should be `full`' )
210211
@@ -232,8 +233,9 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
232233 self .uid_list = torch .tensor (self .uid_list )
233234 self .user_df = dataset .join (Interaction ({uid_field : self .uid_list }))
234235
235- super ().__init__ (config , dataset , sampler , neg_sample_args ,
236- batch_size = batch_size , dl_format = dl_format , shuffle = shuffle )
236+ super ().__init__ (
237+ config , dataset , sampler , neg_sample_args , batch_size = batch_size , dl_format = dl_format , shuffle = shuffle
238+ )
237239
238240 def _set_user_property (self , uid , used_item , positive_item ):
239241 if uid is None :
@@ -260,7 +262,7 @@ def _shuffle(self):
260262 self .logger .warnning ('GeneralFullDataLoader can\' t shuffle' )
261263
262264 def _next_batch_data (self ):
263- user_df = self .user_df [self .pr : self .pr + self .step ]
265+ user_df = self .user_df [self .pr :self .pr + self .step ]
264266 cur_data = self ._neg_sampling (user_df )
265267 self .pr += self .step
266268 return cur_data
0 commit comments