Skip to content

Commit ed5bba6

Browse files
authored
Merge pull request #385 from yinhaofeng/ncf
Ncf
2 parents e2e21d2 + 7c4beab commit ed5bba6

File tree

18 files changed

+1077
-0
lines changed

18 files changed

+1077
-0
lines changed

.pre-commit-config.yaml

100644100755
File mode changed.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
mkdir Data
2+
pip3 install scipy
3+
wget https://paddlerec.bj.bcebos.com/ncf/Data.zip
4+
unzip Data/Data.zip -d Data/
5+
python3 get_train_data.py --num_neg 4 --train_data_path "Data/train_data.csv"
6+
python3 get_test_data.py
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
17+
filename = './Data/ml-1m.test.negative'
18+
f = open(filename, "r")
19+
lines = f.readlines()
20+
f.close()
21+
filename = './test_data.csv'
22+
f = open(filename, "w")
23+
for line in lines:
24+
line = line.strip().split("\t")
25+
user_id = line[0].strip("()").split(",")[0]
26+
positive_item = line[0].strip("()").split(",")[1]
27+
negative_item = []
28+
for item in line[1:]:
29+
negative_item.append(int(item))
30+
31+
f.write(user_id + "," + positive_item + "," + "1" + "\n")
32+
for item in negative_item:
33+
f.write(user_id + "," + str(item) + "," + "0" + "\n")
34+
35+
f.close()
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import scipy.sparse as sp
16+
import numpy as np
17+
from time import time
18+
import argparse
19+
20+
21+
def parse_args():
22+
parser = argparse.ArgumentParser(description="Run GMF.")
23+
parser.add_argument(
24+
'--path', nargs='?', default='Data/', help='Input data path.')
25+
parser.add_argument(
26+
'--dataset', nargs='?', default='ml-1m', help='Choose a dataset.')
27+
parser.add_argument(
28+
'--num_neg',
29+
type=int,
30+
default=4,
31+
help='Number of negative instances to pair with a positive instance.')
32+
parser.add_argument(
33+
'--train_data_path',
34+
type=str,
35+
default="Data/train_data.csv",
36+
help='train_data_path')
37+
return parser.parse_args()
38+
39+
40+
def get_train_data(filename, write_file, num_negatives):
41+
'''
42+
Read .rating file and Return dok matrix.
43+
The first line of .rating file is: num_users\t num_items
44+
'''
45+
# Get number of users and items
46+
num_users, num_items = 0, 0
47+
with open(filename, "r") as f:
48+
line = f.readline()
49+
while line != None and line != "":
50+
arr = line.split("\t")
51+
u, i = int(arr[0]), int(arr[1])
52+
num_users = max(num_users, u)
53+
num_items = max(num_items, i)
54+
line = f.readline()
55+
print("users_num:", num_users, "items_num:", num_items)
56+
# Construct matrix
57+
mat = sp.dok_matrix((num_users + 1, num_items + 1), dtype=np.float32)
58+
with open(filename, "r") as f:
59+
line = f.readline()
60+
while line != None and line != "":
61+
arr = line.split("\t")
62+
user, item, rating = int(arr[0]), int(arr[1]), float(arr[2])
63+
if (rating > 0):
64+
mat[user, item] = 1.0
65+
line = f.readline()
66+
67+
file = open(write_file, 'w')
68+
print("writing " + write_file)
69+
70+
for (u, i) in mat.keys():
71+
# positive instance
72+
user_input = str(u)
73+
item_input = str(i)
74+
label = str(1)
75+
sample = "{0},{1},{2}".format(user_input, item_input, label) + "\n"
76+
file.write(sample)
77+
# negative instances
78+
for t in range(num_negatives):
79+
j = np.random.randint(num_items)
80+
while (u, j) in mat.keys():
81+
j = np.random.randint(num_items)
82+
user_input = str(u)
83+
item_input = str(j)
84+
label = str(0)
85+
sample = "{0},{1},{2}".format(user_input, item_input, label) + "\n"
86+
file.write(sample)
87+
88+
89+
if __name__ == "__main__":
90+
args = parse_args()
91+
get_train_data(args.path + args.dataset + ".train.rating",
92+
args.train_data_path, args.num_neg)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# NCF使用的数据集
2+
3+
本数据集供NCF模型复现论文使用,使用的是初步处理过后的数据,分为两个数据集:ml-1m(即MovieLens数据集)和pinterest-20(即Pinterest数据集)
4+
每个数据集分为三个文件,后缀分别为:(.test.negative),(.test.rating),(.train.rating)
5+
6+
在train.rating和test.rating中的数据格式为:
7+
user_id + \t + item_id + \t + rating(用户评分) + \t + timestamp(时间戳)
8+
在test.negative中的数据格式为:
9+
(userID,itemID) + \t + negativeItemID1 + \t + negativeItemID2 …(包含99个negative样本)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
mkdir big_train
2+
cd big_train
3+
wget https://paddlerec.bj.bcebos.com/ncf/train_data.csv
4+
cd ..
5+
mkdir big_test
6+
cd big_test
7+
wget https://paddlerec.bj.bcebos.com/ncf/test_data.csv
8+
cd ..
9+
wget https://paddlerec.bj.bcebos.com/ncf/Data.zip

datasets/readme.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ sh data_process.sh
2525
|[senti_clas](https://baidu-nlp.bj.bcebos.com/sentiment_classification-dataset-1.0.0.tar.gz)|情感倾向分析(Sentiment Classification,简称Senta)针对带有主观描述的中文文本,可自动判断该文本的情感极性类别并给出相应的置信度。情感类型分为积极、消极。情感倾向分析能够帮助企业理解用户消费习惯、分析热点话题和危机舆情监控,为企业提供有利的决策支持|--|
2626
|[one_billion](http://www.statmt.org/lm-benchmark/)|拥有十亿个单词基准,为语言建模实验提供标准的训练和测试|[One Billion Word Benchmark for Measuring Progress in Statistical Language Modeling](https://arxiv.org/abs/1312.3005)|
2727
|[MIND](https://paddlerec.bj.bcebos.com/datasets/MIND/bigdata.zip)|MIND即MIcrosoft News Dataset的简写,MIND里的数据来自Microsoft News用户的行为日志。MIND的数据集里包含了1,000,000的用户以及这些用户与160,000的文章的交互行为。|[Microsoft(2020)](https://msnews.github.io)|
28+
|[movielens_pinterest_NCF](https://paddlerec.bj.bcebos.com/ncf/Data.zip)|论文原作者处理过的movielens数据集和pinterest数据集|[《Neural Collaborative Filtering 》](https://arxiv.org/pdf/1708.05031.pdf)|

models/recall/ncf/config.yaml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
runner:
16+
train_data_dir: "data/train"
17+
train_reader_path: "movielens_reader" # importlib format
18+
train_batch_size: 5
19+
model_save_path: "output_model_ncf"
20+
21+
use_gpu: False
22+
epochs: 3
23+
print_interval: 10
24+
25+
test_data_dir: "data/test"
26+
infer_reader_path: "movielens_reader" # importlib format
27+
infer_batch_size: 5
28+
infer_load_path: "output_model_ncf"
29+
infer_start_epoch: 2
30+
infer_end_epoch: 3
31+
32+
hyper_parameters:
33+
optimizer:
34+
class: adam
35+
learning_rate: 0.001
36+
num_users: 6040
37+
num_items: 3706
38+
mf_dim: 8
39+
mode: "NCF_NeuMF" # optional: NCF_NeuMF, NCF_GMF, NCF_MLP
40+
fc_layers: [64, 32, 16, 8]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
runner:
16+
train_data_dir: "../../../datasets/movielens_pinterest_NCF/big_train"
17+
train_reader_path: "movielens_reader" # importlib format
18+
train_batch_size: 256
19+
model_save_path: "output_model_ncf"
20+
21+
use_gpu: False
22+
epochs: 20
23+
print_interval: 1
24+
25+
test_data_dir: "../../../datasets/movielens_pinterest_NCF/big_test"
26+
infer_reader_path: "movielens_reader" # importlib format
27+
infer_batch_size: 1
28+
infer_load_path: "output_model_ncf"
29+
infer_start_epoch: 19
30+
infer_end_epoch: 20
31+
32+
hyper_parameters:
33+
optimizer:
34+
class: adam
35+
learning_rate: 0.001
36+
num_users: 6040
37+
num_items: 3706
38+
mf_dim: 8
39+
mode: "NCF_NeuMF" # optional: NCF_NeuMF, NCF_GMF, NCF_MLP
40+
fc_layers: [64, 32, 16, 8]
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
4764,174,1
2+
4764,2958,0
3+
4764,452,0
4+
4764,1946,0
5+
4764,3208,0
6+
2044,2237,1
7+
2044,1998,0
8+
2044,328,0
9+
2044,1542,0
10+
2044,1932,0
11+
4276,65,1
12+
4276,3247,0
13+
4276,942,0
14+
4276,3666,0
15+
4276,2222,0
16+
3933,682,1
17+
3933,2451,0
18+
3933,3695,0
19+
3933,1643,0
20+
3933,3568,0
21+
1151,1265,1
22+
1151,118,0
23+
1151,2532,0
24+
1151,2083,0
25+
1151,2350,0
26+
1757,876,1
27+
1757,201,0
28+
1757,3633,0
29+
1757,1068,0
30+
1757,2549,0
31+
3370,276,1
32+
3370,2435,0
33+
3370,606,0
34+
3370,910,0
35+
3370,2146,0
36+
5137,1018,1
37+
5137,2163,0
38+
5137,3167,0
39+
5137,2315,0
40+
5137,3595,0
41+
3933,2831,1
42+
3933,2881,0
43+
3933,2949,0
44+
3933,3660,0
45+
3933,417,0
46+
3102,999,1
47+
3102,1902,0
48+
3102,2161,0
49+
3102,3042,0
50+
3102,1113,0
51+
2022,336,1
52+
2022,1672,0
53+
2022,2656,0
54+
2022,3649,0
55+
2022,883,0
56+
2664,655,1
57+
2664,3660,0
58+
2664,1711,0
59+
2664,3386,0
60+
2664,1668,0
61+
25,701,1
62+
25,32,0
63+
25,2482,0
64+
25,3177,0
65+
25,2767,0
66+
1738,1643,1
67+
1738,2187,0
68+
1738,228,0
69+
1738,650,0
70+
1738,3101,0
71+
5411,1241,1
72+
5411,2546,0
73+
5411,3019,0
74+
5411,3618,0
75+
5411,1674,0
76+
638,579,1
77+
638,3512,0
78+
638,783,0
79+
638,2111,0
80+
638,1880,0
81+
3554,200,1
82+
3554,2893,0
83+
3554,2428,0
84+
3554,969,0
85+
3554,2741,0
86+
4283,1074,1
87+
4283,3056,0
88+
4283,2032,0
89+
4283,405,0
90+
4283,1505,0
91+
5111,200,1
92+
5111,3488,0
93+
5111,477,0
94+
5111,2790,0
95+
5111,40,0
96+
3964,515,1
97+
3964,1528,0
98+
3964,2173,0
99+
3964,1701,0
100+
3964,2832,0

0 commit comments

Comments
 (0)