Skip to content

Commit 233e6c4

Browse files
committed
fix format
1 parent 8f74bef commit 233e6c4

File tree

6 files changed

+7
-11
lines changed

6 files changed

+7
-11
lines changed

models/recall/mhcn/config.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
runner:
16-
train_data_dir: "data/train"
16+
train_data_dir: "data"
1717
rating_file: "data/ratings_100.txt"
1818
relation_file: "data/trusts_100.txt"
1919
train_batch_size: 8
@@ -23,13 +23,14 @@ runner:
2323
epochs: 5
2424
print_interval: 1
2525
model_save_path: "output_model_mhcn"
26-
test_data_dir: "data/test"
26+
test_data_dir: "data"
2727
infer_batch_size: 8
2828
infer_reader_path: "lastfm_reader" # importlib format
2929
infer_load_path: "output_model_mhcn"
3030
infer_start_epoch: 0
3131
infer_end_epoch: 5
3232
top_k: 10
33+
is_train: True
3334

3435
# hyper parameters of user-defined network
3536
hyper_parameters:

models/recall/mhcn/config_bigdata.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
runner:
16-
train_data_dir: "data/train"
16+
train_data_dir: "data"
1717
rating_file: "../../../datasets/LastFM_MHCN/lastfm/ratings.txt"
1818
relation_file: "../../../datasets/LastFM_MHCN/lastfm/trusts.txt"
1919
train_batch_size: 2000
@@ -23,13 +23,14 @@ runner:
2323
epochs: 120
2424
print_interval: 1
2525
model_save_path: "output_model_mhcn_all"
26-
test_data_dir: "data/test"
26+
test_data_dir: "data"
2727
infer_batch_size: 2000
2828
infer_reader_path: "lastfm_reader" # importlib format
2929
infer_load_path: "output_model_mhcn_all"
3030
infer_start_epoch: 0
3131
infer_end_epoch: 117
3232
top_k: 10
33+
is_train: True
3334

3435
# hyper parameters of user-defined network
3536
hyper_parameters:

models/recall/mhcn/data/test/test.txt

Whitespace-only changes.

models/recall/mhcn/data/train/train.txt

Whitespace-only changes.

models/recall/mhcn/lastfm_reader.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,7 @@ def crossValidation(data, k, binarized=False):
386386
class RecDataset(IterableDataset):
387387
def __init__(self, file_list, config):
388388
super(RecDataset, self).__init__()
389-
self.file_list = file_list
390-
self.is_train = True if "train" in file_list or file_list else False
389+
self.is_train = config.get("runner.is_train", True)
391390
self.trainingSet = loadDataSet(
392391
config.get("runner.rating_file", None),
393392
bTest=False,

models/recall/mhcn/net.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ def __init__(self, n_layers=2, emb_size=50, config=None):
3333
self.data, self.social = self.recDataset.get_dataset()
3434
self.num_users, self.num_items, _ = self.data.trainingSize()
3535

36-
# user num: 1891
37-
print("user num: ", self.num_users)
38-
# item num: 15438
39-
print("item num: ", self.num_items)
40-
4136
self.userAdjacency = None
4237
self.itemAdjacency = None
4338

0 commit comments

Comments
 (0)