Skip to content

Commit 4aaa458

Browse files
committed
fix codestyle problems
1 parent 7314eab commit 4aaa458

File tree

18 files changed

+1184
-2594
lines changed

18 files changed

+1184
-2594
lines changed

datasets/AmazonBook/preprocess.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import sys
15+
import os
16+
import json
17+
import numpy as np
18+
import argparse
19+
import random
20+
21+
# Command-line options shared by the train/valid/test preprocessing runs.
parser = argparse.ArgumentParser()
# Which split of book_data/book_<type>.txt to reshard.
parser.add_argument("-type", type=str, default="train", help="train|valid|test")
# Maximum length of the per-user history window.
parser.add_argument("-maxlen", type=int, default=20)
25+
26+
27+
def load_graph(source):
    """Parse a CSV of ``user_id,item_id,timestamp`` rows.

    Returns a dict mapping each user_id to the list of its item_ids,
    ordered by ascending timestamp.
    """
    interactions = {}
    with open(source) as handle:
        for row in handle:
            fields = row.strip().split(',')
            uid, iid, ts = int(fields[0]), int(fields[1]), int(fields[2])
            interactions.setdefault(uid, []).append((iid, ts))
    # Sort each user's events chronologically and drop the timestamps.
    return {
        uid: [iid for iid, _ in sorted(events, key=lambda e: e[1])]
        for uid, events in interactions.items()
    }
43+
44+
45+
if __name__ == "__main__":
46+
args = parser.parse_args()
47+
filelist = []
48+
for i in range(10):
49+
filelist.append(open(args.type + "/part-%d" % (i), "w"))
50+
action_graph = load_graph("book_data/book_" + args.type + ".txt")
51+
if args.type == "train":
52+
for uid, item_list in action_graph.items():
53+
for i in range(4, len(item_list)):
54+
if i >= args.maxlen:
55+
hist_item = item_list[i - args.maxlen:i]
56+
else:
57+
hist_item = item_list[:i]
58+
target_item = item_list[i]
59+
print(
60+
" ".join(["user_id:" + str(uid)] + [
61+
"hist_item:" + str(n) for n in hist_item
62+
] + ["target_item:" + str(target_item)]),
63+
file=random.choice(filelist))
64+
else:
65+
for uid, item_list in action_graph.items():
66+
k = int(len(item_list) * 0.8)
67+
if k >= args.maxlen:
68+
hist_item = item_list[k - args.maxlen:k]
69+
else:
70+
hist_item = item_list[:k]
71+
target_item = item_list[k:]
72+
print(
73+
" ".join(["user_id:" + str(uid), "target_item:0"] + [
74+
"hist_item:" + str(n) for n in hist_item
75+
] + ["eval_item:" + str(n) for n in target_item]),
76+
file=random.choice(filelist))

datasets/AmazonBook/run.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
# Download the preprocessed AmazonBook dataset and build train/valid splits.
wget https://paddlerec.bj.bcebos.com/datasets/AmazonBook/AmazonBook.tar.gz

tar -xvf AmazonBook.tar.gz

# Recreate the output split directories from scratch.
rm -rf train valid
mkdir train
mkdir valid

# The raw train file is moved as-is; the valid split is resharded by
# preprocess.py (which reads book_data/book_valid.txt and writes valid/part-*).
mv book_data/book_train.txt train
python preprocess.py -type valid -maxlen 20

datasets/readme.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ sh data_process.sh
2626
|[one_billion](http://www.statmt.org/lm-benchmark/)|拥有十亿个单词基准,为语言建模实验提供标准的训练和测试|[One Billion Word Benchmark for Measuring Progress in Statistical Language Modeling](https://arxiv.org/abs/1312.3005)|
2727
|[MIND](https://paddlerec.bj.bcebos.com/datasets/MIND/bigdata.zip)|MIND即MIcrosoft News Dataset的简写,MIND里的数据来自Microsoft News用户的行为日志。MIND的数据集里包含了1,000,000的用户以及这些用户与160,000的文章的交互行为。|[Microsoft(2020)](https://msnews.github.io)|
2828
|[movielens_pinterest_NCF](https://paddlerec.bj.bcebos.com/ncf/Data.zip)|论文原作者处理过的movielens数据集和pinterest数据集|[《Neural Collaborative Filtering 》](https://arxiv.org/pdf/1708.05031.pdf)|
29+
|[AmazonBook](https://paddlerec.bj.bcebos.com/datasets/AmazonBook/AmazonBook.tar.gz)|论文原作者处理过的AmazonBook数据集 |[《Controllable Multi-Interest Framework for Recommendation》](https://arxiv.org/abs/2005.09347)|

models/recall/mind/README.md

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,16 @@
55
├── data #样例数据
66
│ ├── demo #demo训练数据
77
│ │ └── demo.txt
8-
│ ├── processs.py #处理全量数据的脚本
9-
│ ├── run.sh #全量数据下载的脚本
108
│ └── valid #demo测试数据
119
│ └── part-0
12-
├── config.yaml #数据配置
10+
├── config.yaml #demo数据配置
11+
├── config_bigdata.yaml #全量数据配置
12+
├── infer.py #评测动态图
1313
├── dygraph_model.py #构建动态图
14-
├── evaluate_dygraph.py #评测动态图
15-
├── evaluate_reader.py #评测数据reader
16-
├── evaluate_static.py #评测静态图
1714
├── mind_reader.py #训练数据reader
15+
├── mind_infer_reader.py #评测数据reader
1816
├── net.py #模型核心组网(动静合一)
17+
├── static_infer.py #评测静态图
1918
└── static_model.py #构建静态图
2019
```
2120

@@ -57,11 +56,14 @@ Multi-Interest Network with Dynamic Routing (MIND) 是通过构建用户和商
5756

5857
测试数据的格式如下:
5958
```
60-
user_id:543354 hist_item:143963 hist_item:157508 hist_item:105486 hist_item:40502 hist_item:167813 hist_item:233564 hist_item:221866 hist_item:280310 hist_item:61638 hist_item:158494 hist_item:74449 hist_item:283630 hist_item:135155 hist_item:96176 hist_item:20139 hist_item:89420 hist_item:247990 hist_item:126605 target_item:172183 target_item:114193 target_item:79966 target_item:134420 target_item:50557
61-
user_id:543362 hist_item:119546 hist_item:78597 hist_item:86809 hist_item:63551 target_item:326165
62-
user_id:543366 hist_item:45463 hist_item:9903 hist_item:3956 hist_item:49726 target_item:199426
59+
user_id:487766 target_item:0 hist_item:17784 hist_item:126 hist_item:36 hist_item:124 hist_item:34 hist_item:1 hist_item:134 hist_item:6331 hist_item:141 hist_item:4336 hist_item:1373 eval_item:1062 eval_item:867 eval_item:62
60+
user_id:487793 target_item:0 hist_item:153428 hist_item:132997 hist_item:155723 hist_item:66546 hist_item:335397 hist_item:1926 eval_item:1122 eval_item:10105
61+
user_id:487805 target_item:0 hist_item:291025 hist_item:25190 hist_item:2820 hist_item:26047 hist_item:47259 hist_item:36376 eval_item:260145 eval_item:83865
62+
user_id:487811 target_item:0 hist_item:180837 hist_item:202701 hist_item:184587 hist_item:211642 eval_item:101621 eval_item:55716
63+
user_id:487820 target_item:0 hist_item:268524 hist_item:44318 hist_item:35153 hist_item:70847 eval_item:238318
64+
user_id:487825 target_item:0 hist_item:35602 hist_item:4353 hist_item:1540 hist_item:72921 eval_item:501
6365
```
64-
其中`hist_item``target_item`均是变长序列,读取方式可以看`evaluate_reader.py`
66+
其中`hist_item`和`eval_item`均是变长序列,读取方式可以看`mind_infer_reader.py`
6567

6668
## 运行环境
6769
PaddlePaddle>=2.0
@@ -75,16 +77,16 @@ os : windows/linux/macos
7577
在mind模型目录的快速执行命令如下:
7678
```
7779
# 进入模型目录
78-
# cd models/recall/word2vec # 在任意目录均可运行
80+
# cd models/recall/mind # 在任意目录均可运行
7981
# 动态图训练
8082
python -u ../../../tools/trainer.py -m config.yaml
8183
# 动态图预测
82-
python -u evaluate_dygraph.py -m config.yaml -top_n 50 #对测试数据进行预测,并通过faiss召回候选结果评测Reacll、NDCG、HitRate指标
84+
python -u infer.py -m config.yaml -top_n 50 #对测试数据进行预测,并通过faiss召回候选结果评测Recall、NDCG、HitRate指标
8385
8486
# 静态图训练
8587
python -u ../../../tools/static_trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml
8688
# 静态图预测
87-
python -u evaluate_static.py -m config.yaml -top_n 50 #对测试数据进行预测,并通过faiss召回候选结果评测Reacll、NDCG、HitRate指标
89+
python -u static_infer.py -m config.yaml -top_n 50 #对测试数据进行预测,并通过faiss召回候选结果评测Reacll、NDCG、HitRate指标
8890
```
8991

9092
## 模型组网
@@ -97,21 +99,21 @@ python -u evaluate_static.py -m config.yaml -top_n 50 #对测试数据进行
9799
在全量数据下模型的指标如下:
98100
| 模型 | batch_size | epoch_num| Recall@50 | NDCG@50 | HitRate@50 |Time of each epoch |
99101
| :------| :------ | :------ | :------| :------ | :------| :------ |
100-
| mind(静态图) | 128 | 6 | 4.61% | 11.28%| 18.92%| -- |
101-
| mind(动态图) | 128 | 6 | 4.57% | 11.25%| 18.99%| -- |
102+
| mind(静态图) | 128 | 6 | 5.61% | 8.96% | 11.81% | -- |
103+
| mind(动态图) | 128 | 6 | 5.54% | 8.85% | 11.75% | -- |
102104

103105
1. 确认您当前所在目录为PaddleRec/models/recall/mind
104-
2. 进入data目录下执行run.sh脚本,会下载处理完成的AmazonBook数据集,并解压到指定目录
106+
2. 进入paddlerec/datasets/AmazonBook目录下执行run.sh脚本,会下载处理完成的AmazonBook数据集,并解压到指定目录
105107
```bash
106-
cd ./data
108+
cd ../../../datasets/AmazonBook
107109
sh run.sh
108110
```
109111
3. 切回模型目录,执行命令运行全量数据
110112
```bash
111-
d - # 切回模型目录
113+
cd - # 切回模型目录
112114
# 动态图训练
113-
python -u ../../../tools/trainer.py -m config.yaml # 全量数据运行config.yaml
114-
python -u evaluate_dygraph.py -m config.yaml # 全量数据运行config.yaml
115+
python -u ../../../tools/trainer.py -m config_bigdata.yaml # 全量数据运行config_bigdata
116+
python -u infer.py -m config_bigdata.yaml # 全量数据运行config_bigdata
115117
```
116118

117119
## 进阶使用

models/recall/mind/config.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ runner:
1818
use_gpu: True
1919
use_auc: False
2020
train_batch_size: 128
21-
epochs: 6
21+
epochs: 2
2222
print_interval: 500
2323
model_save_path: "output_model_mind"
2424
infer_batch_size: 128
25-
infer_reader_path: "evaluate_reader" # importlib format
25+
infer_reader_path: "mind_infer_reader" # importlib format
2626
test_data_dir: "data/valid"
2727
infer_load_path: "output_model_mind"
2828
infer_start_epoch: 0
29-
infer_end_epoch: 4
29+
infer_end_epoch: 1
3030

3131
# distribute_config
3232
# sync_mode: "async"
@@ -48,4 +48,3 @@ hyper_parameters:
4848
neg_samples: 1280
4949
maxlen: 20
5050
pow_p: 1.0
51-
distributed_embedding: 0
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
runner:
16+
train_data_dir: "../../../datasets/AmazonBook/train"
17+
train_reader_path: "mind_reader" # importlib format
18+
use_gpu: True
19+
use_auc: False
20+
train_batch_size: 128
21+
epochs: 6
22+
print_interval: 500
23+
model_save_path: "output_model_mind"
24+
infer_batch_size: 128
25+
infer_reader_path: "mind_infer_reader" # importlib format
26+
test_data_dir: "../../../datasets/AmazonBook/valid"
27+
infer_load_path: "output_model_mind"
28+
infer_start_epoch: 0
29+
infer_end_epoch: 1
30+
31+
# distribute_config
32+
# sync_mode: "async"
33+
# split_file_list: False
34+
# thread_num: 1
35+
36+
37+
# hyper parameters of user-defined network
38+
hyper_parameters:
39+
# optimizer config
40+
optimizer:
41+
class: Adam
42+
learning_rate: 0.005
43+
# strategy: async
44+
# user-defined <key, value> pairs
45+
item_count: 367983
46+
embedding_dim: 64
47+
hidden_size: 64
48+
neg_samples: 1280
49+
maxlen: 20
50+
pow_p: 1.0
Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -11,66 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import sys
15-
import os
16-
import json
17-
import numpy as np
18-
import argparse
19-
import random
20-
21-
parser = argparse.ArgumentParser()
22-
parser.add_argument(
23-
"-type", type=str, default="train", help="train|valid|test")
24-
parser.add_argument("-maxlen", type=int, default=20)
25-
26-
27-
def load_graph(source):
28-
graph = {}
29-
with open(source) as fr:
30-
for line in fr:
31-
conts = line.strip().split(',')
32-
user_id = int(conts[0])
33-
item_id = int(conts[1])
34-
time_stamp = int(conts[2])
35-
if user_id not in graph:
36-
graph[user_id] = []
37-
graph[user_id].append((item_id, time_stamp))
38-
39-
for user_id, value in graph.items():
40-
value.sort(key=lambda x: x[1])
41-
graph[user_id] = [x[0] for x in value]
42-
return graph
43-
44-
45-
if __name__ == "__main__":
46-
args = parser.parse_args()
47-
filelist = []
48-
for i in range(10):
49-
filelist.append(open(args.type + "/part-%d" % (i), "w"))
50-
action_graph = load_graph("data/book_data/book_" + args.type + ".txt")
51-
if args.type == "train":
52-
for uid, item_list in action_graph.items():
53-
for i in range(4, len(item_list)):
54-
if i >= args.maxlen:
55-
hist_item = item_list[i - args.maxlen:i]
56-
else:
57-
hist_item = item_list[:i]
58-
target_item = item_list[i]
59-
print(
60-
" ".join(["user_id:" + str(uid)] + [
61-
"hist_item:" + str(n) for n in hist_item
62-
] + ["target_item:" + str(target_item)]),
63-
file=random.choice(filelist))
64-
else:
65-
for uid, item_list in action_graph.items():
66-
k = int(len(item_list) * 0.8)
67-
if k >= args.maxlen:
68-
hist_item = item_list[k - args.maxlen:k]
69-
else:
70-
hist_item = item_list[:k]
71-
target_item = item_list[k:]
72-
print(
73-
" ".join(["user_id:" + str(uid), "target_item:0"] + [
74-
"hist_item:" + str(n) for n in hist_item
75-
] + ["eval_item:" + str(n) for n in target_item]),
76-
file=random.choice(filelist))

models/recall/mind/data/run.sh

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)