Skip to content

Commit 4c4838b

Browse files
authored
Merge pull request #664 from chenyushuo/master
FORMAT: code format by yapf
2 parents ccc0acc + ae5c541 commit 4c4838b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+1254
-729
lines changed

recbole/config/configurator.py

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -89,14 +89,16 @@ def _build_yaml_loader(self):
8989
loader = yaml.FullLoader
9090
loader.add_implicit_resolver(
9191
u'tag:yaml.org,2002:float',
92-
re.compile(u'''^(?:
92+
re.compile(
93+
u'''^(?:
9394
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
9495
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
9596
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
9697
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
9798
|[-+]?\\.(?:inf|Inf|INF)
98-
|\\.(?:nan|NaN|NAN))$''', re.X),
99-
list(u'-+0123456789.'))
99+
|\\.(?:nan|NaN|NAN))$''', re.X
100+
), list(u'-+0123456789.')
101+
)
100102
return loader
101103

102104
def _convert_config_dict(self, config_dict):
@@ -175,7 +177,8 @@ def _get_model_and_dataset(self, model, dataset):
175177
except KeyError:
176178
raise KeyError(
177179
'model need to be specified in at least one of the these ways: '
178-
'[model variable, config file, config dict, command line] ')
180+
'[model variable, config file, config dict, command line] '
181+
)
179182
if not isinstance(model, str):
180183
final_model_class = model
181184
final_model = model.__name__
@@ -187,8 +190,10 @@ def _get_model_and_dataset(self, model, dataset):
187190
try:
188191
final_dataset = self.external_config_dict['dataset']
189192
except KeyError:
190-
raise KeyError('dataset need to be specified in at least one of the these ways: '
191-
'[dataset variable, config file, config dict, command line] ')
193+
raise KeyError(
194+
'dataset need to be specified in at least one of the these ways: '
195+
'[dataset variable, config file, config dict, command line] '
196+
)
192197
else:
193198
final_dataset = dataset
194199

@@ -223,8 +228,9 @@ def _load_internal_config_dict(self, model, model_class, dataset):
223228
if os.path.isfile(file):
224229
config_dict = self._update_internal_config_dict(file)
225230
if file == dataset_init_file:
226-
self.parameters['Dataset'] += [key for key in config_dict.keys() if
227-
key not in self.parameters['Dataset']]
231+
self.parameters['Dataset'] += [
232+
key for key in config_dict.keys() if key not in self.parameters['Dataset']
233+
]
228234

229235
self.internal_config_dict['MODEL_TYPE'] = model_class.type
230236
if self.internal_config_dict['MODEL_TYPE'] == ModelType.GENERAL:
@@ -272,8 +278,7 @@ def _set_default_parameters(self):
272278
elif self.final_config_dict['loss_type'] in ['BPR']:
273279
self.final_config_dict['MODEL_INPUT_TYPE'] = InputType.PAIRWISE
274280
else:
275-
raise ValueError('Either Model has attr \'input_type\','
276-
'or arg \'loss_type\' should exist in config.')
281+
raise ValueError('Either Model has attr \'input_type\',' 'or arg \'loss_type\' should exist in config.')
277282

278283
eval_type = None
279284
for metric in self.final_config_dict['metrics']:
@@ -324,10 +329,10 @@ def __str__(self):
324329
args_info = ''
325330
for category in self.parameters:
326331
args_info += category + ' Hyper Parameters: \n'
327-
args_info += '\n'.join(
328-
["{}={}".format(arg, value)
329-
for arg, value in self.final_config_dict.items()
330-
if arg in self.parameters[category]])
332+
args_info += '\n'.join([
333+
"{}={}".format(arg, value) for arg, value in self.final_config_dict.items()
334+
if arg in self.parameters[category]
335+
])
331336
args_info += '\n\n'
332337
return args_info
333338

recbole/config/eval_setting.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
# @Author : Yupeng Hou, Yushuo Chen
88
# @Email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn
99

10-
1110
"""
1211
recbole.config.eval_setting
1312
################################

recbole/data/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
11
from recbole.data.utils import *
22

3-
43
__all__ = ['create_dataset', 'data_preparation']

recbole/data/dataloader/abstract_dataloader.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -42,8 +42,7 @@ class AbstractDataLoader(object):
4242
"""
4343
dl_type = None
4444

45-
def __init__(self, config, dataset,
46-
batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
45+
def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
4746
self.config = config
4847
self.logger = getLogger()
4948
self.dataset = dataset

recbole/data/dataloader/general_dataloader.py

Lines changed: 19 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,8 @@ class GeneralDataLoader(AbstractDataLoader):
3434
"""
3535
dl_type = DataLoaderType.ORIGIN
3636

37-
def __init__(self, config, dataset,
38-
batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
39-
super().__init__(config, dataset,
40-
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
37+
def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
38+
super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
4139

4240
@property
4341
def pr_end(self):
@@ -47,7 +45,7 @@ def _shuffle(self):
4745
self.dataset.shuffle()
4846

4947
def _next_batch_data(self):
50-
cur_data = self.dataset[self.pr: self.pr + self.step]
48+
cur_data = self.dataset[self.pr:self.pr + self.step]
5149
self.pr += self.step
5250
return cur_data
5351

@@ -70,14 +68,16 @@ class GeneralNegSampleDataLoader(NegSampleByMixin, AbstractDataLoader):
7068
shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``.
7169
"""
7270

73-
def __init__(self, config, dataset, sampler, neg_sample_args,
74-
batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
71+
def __init__(
72+
self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
73+
):
7574
self.uid_field = dataset.uid_field
7675
self.iid_field = dataset.iid_field
7776
self.uid_list, self.uid2index, self.uid2items_num = None, None, None
7877

79-
super().__init__(config, dataset, sampler, neg_sample_args,
80-
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
78+
super().__init__(
79+
config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle
80+
)
8181

8282
def setup(self):
8383
if self.user_inter_in_one_batch:
@@ -132,7 +132,7 @@ def _shuffle(self):
132132

133133
def _next_batch_data(self):
134134
if self.user_inter_in_one_batch:
135-
uid_list = self.uid_list[self.pr: self.pr + self.step]
135+
uid_list = self.uid_list[self.pr:self.pr + self.step]
136136
data_list = []
137137
for uid in uid_list:
138138
index = self.uid2index[uid]
@@ -144,7 +144,7 @@ def _next_batch_data(self):
144144
self.pr += self.step
145145
return cur_data
146146
else:
147-
cur_data = self._neg_sampling(self.dataset[self.pr: self.pr + self.step])
147+
cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step])
148148
self.pr += self.step
149149
return cur_data
150150

@@ -167,7 +167,7 @@ def _neg_sample_by_point_wise_sampling(self, inter_feat, neg_iids):
167167
new_data[self.iid_field][pos_inter_num:] = neg_iids
168168
new_data = self.dataset.join(new_data)
169169
labels = torch.zeros(pos_inter_num * self.times)
170-
labels[: pos_inter_num] = 1.0
170+
labels[:pos_inter_num] = 1.0
171171
new_data.update(Interaction({self.label_field: labels}))
172172
return new_data
173173

@@ -203,8 +203,9 @@ class GeneralFullDataLoader(NegSampleMixin, AbstractDataLoader):
203203
"""
204204
dl_type = DataLoaderType.FULL
205205

206-
def __init__(self, config, dataset, sampler, neg_sample_args,
207-
batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
206+
def __init__(
207+
self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
208+
):
208209
if neg_sample_args['strategy'] != 'full':
209210
raise ValueError('neg_sample strategy in GeneralFullDataLoader() should be `full`')
210211

@@ -232,8 +233,9 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
232233
self.uid_list = torch.tensor(self.uid_list)
233234
self.user_df = dataset.join(Interaction({uid_field: self.uid_list}))
234235

235-
super().__init__(config, dataset, sampler, neg_sample_args,
236-
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
236+
super().__init__(
237+
config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle
238+
)
237239

238240
def _set_user_property(self, uid, used_item, positive_item):
239241
if uid is None:
@@ -260,7 +262,7 @@ def _shuffle(self):
260262
self.logger.warnning('GeneralFullDataLoader can\'t shuffle')
261263

262264
def _next_batch_data(self):
263-
user_df = self.user_df[self.pr: self.pr + self.step]
265+
user_df = self.user_df[self.pr:self.pr + self.step]
264266
cur_data = self._neg_sampling(user_df)
265267
self.pr += self.step
266268
return cur_data

recbole/data/dataloader/knowledge_dataloader.py

Lines changed: 31 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,7 @@ class KGDataLoader(AbstractDataLoader):
3535
However, in :class:`KGDataLoader`, it's guaranteed to be ``True``.
3636
"""
3737

38-
def __init__(self, config, dataset, sampler,
39-
batch_size=1, dl_format=InputType.PAIRWISE, shuffle=False):
38+
def __init__(self, config, dataset, sampler, batch_size=1, dl_format=InputType.PAIRWISE, shuffle=False):
4039
self.sampler = sampler
4140
self.neg_sample_num = 1
4241

@@ -48,8 +47,7 @@ def __init__(self, config, dataset, sampler,
4847
self.neg_tid_field = self.neg_prefix + self.tid_field
4948
dataset.copy_field_property(self.neg_tid_field, self.tid_field)
5049

51-
super().__init__(config, dataset,
52-
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
50+
super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
5351

5452
def setup(self):
5553
"""Make sure that the :attr:`shuffle` is True. If :attr:`shuffle` is False, it will be changed to True
@@ -67,7 +65,7 @@ def _shuffle(self):
6765
self.dataset.kg_feat.shuffle()
6866

6967
def _next_batch_data(self):
70-
cur_data = self._neg_sampling(self.dataset.kg_feat[self.pr: self.pr + self.step])
68+
cur_data = self._neg_sampling(self.dataset.kg_feat[self.pr:self.pr + self.step])
7169
self.pr += self.step
7270
return cur_data
7371

@@ -112,28 +110,44 @@ class KnowledgeBasedDataLoader(AbstractDataLoader):
112110
and user-item interaction information.
113111
"""
114112

115-
def __init__(self, config, dataset, sampler, kg_sampler, neg_sample_args,
116-
batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
113+
def __init__(
114+
self,
115+
config,
116+
dataset,
117+
sampler,
118+
kg_sampler,
119+
neg_sample_args,
120+
batch_size=1,
121+
dl_format=InputType.POINTWISE,
122+
shuffle=False
123+
):
117124

118125
# using sampler
119-
self.general_dataloader = GeneralNegSampleDataLoader(config=config, dataset=dataset,
120-
sampler=sampler, neg_sample_args=neg_sample_args,
121-
batch_size=batch_size, dl_format=dl_format,
122-
shuffle=shuffle)
126+
self.general_dataloader = GeneralNegSampleDataLoader(
127+
config=config,
128+
dataset=dataset,
129+
sampler=sampler,
130+
neg_sample_args=neg_sample_args,
131+
batch_size=batch_size,
132+
dl_format=dl_format,
133+
shuffle=shuffle
134+
)
123135

124136
# using kg_sampler and dl_format is pairwise
125-
self.kg_dataloader = KGDataLoader(config, dataset, kg_sampler,
126-
batch_size=batch_size, dl_format=InputType.PAIRWISE, shuffle=True)
137+
self.kg_dataloader = KGDataLoader(
138+
config, dataset, kg_sampler, batch_size=batch_size, dl_format=InputType.PAIRWISE, shuffle=True
139+
)
127140

128141
self.state = None
129142

130-
super().__init__(config, dataset,
131-
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
143+
super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
132144

133145
def __iter__(self):
134146
if self.state is None:
135-
raise ValueError('The dataloader\'s state must be set when using the kg based dataloader, '
136-
'you should call set_mode() before __iter__()')
147+
raise ValueError(
148+
'The dataloader\'s state must be set when using the kg based dataloader, '
149+
'you should call set_mode() before __iter__()'
150+
)
137151
if self.state == KGDataLoaderState.KG:
138152
return self.kg_dataloader.__iter__()
139153
elif self.state == KGDataLoaderState.RS:

recbole/data/dataloader/neg_sample_mixin.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -33,16 +33,16 @@ class NegSampleMixin(AbstractDataLoader):
3333
"""
3434
dl_type = DataLoaderType.NEGSAMPLE
3535

36-
def __init__(self, config, dataset, sampler, neg_sample_args,
37-
batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
36+
def __init__(
37+
self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
38+
):
3839
if neg_sample_args['strategy'] not in ['by', 'full']:
3940
raise ValueError(f"Neg_sample strategy [{neg_sample_args['strategy']}] has not been implemented.")
4041

4142
self.sampler = sampler
4243
self.neg_sample_args = neg_sample_args
4344

44-
super().__init__(config, dataset,
45-
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
45+
super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
4646

4747
def setup(self):
4848
"""Do batch size adaptation.
@@ -95,8 +95,9 @@ class NegSampleByMixin(NegSampleMixin):
9595
shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``.
9696
"""
9797

98-
def __init__(self, config, dataset, sampler, neg_sample_args,
99-
batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
98+
def __init__(
99+
self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
100+
):
100101
if neg_sample_args['strategy'] != 'by':
101102
raise ValueError('neg_sample strategy in GeneralInteractionBasedDataLoader() should be `by`')
102103

@@ -124,8 +125,9 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
124125
else:
125126
raise ValueError(f'`neg sampling by` with dl_format [{dl_format}] not been implemented.')
126127

127-
super().__init__(config, dataset, sampler, neg_sample_args,
128-
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
128+
super().__init__(
129+
config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle
130+
)
129131

130132
def _neg_sample_by_pair_wise_sampling(self, *args):
131133
"""Pair-wise sampling.

0 commit comments

Comments
 (0)