
Commit b64f0eb

Merge branch 'master' into mmoe_fix_0917
2 parents 1039612 + d99cb4f commit b64f0eb

File tree

8 files changed: +558 −22 lines


models/recall/gru4rec/README.md

Lines changed: 206 additions & 0 deletions
# GRU4REC

The directory layout of this example:

```
├── data                  # sample data and data-processing scripts
    ├── train
        ├── small_train.txt   # sample training data
    ├── test
        ├── small_test.txt    # sample test data
    ├── convert_format.py     # data format conversion script
    ├── download.py           # data download script
    ├── preprocess.py         # data preprocessing script
    ├── text2paddle.py        # generates the Paddle training data
├── __init__.py
├── README.md             # this document
├── model.py              # model definition
├── config.yaml           # configuration file
├── data_prepare.sh       # one-step data preparation script
├── rsc15_reader.py       # reader
```

Note: before reading this example, we recommend first going through the
[PaddleRec tutorial](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md)

---
## Contents

- [Model overview](#model-overview)
- [Data processing](#data-processing)
- [Environment](#environment)
- [Quick start](#quick-start)
- [Paper reproduction](#paper-reproduction)
- [Advanced usage](#advanced-usage)
- [FAQ](#faq)
## Model overview
The GRU4REC model is described in the paper [Session-based Recommendations with Recurrent Neural Networks](https://arxiv.org/abs/1511.06939).

The paper's contribution is the first application of an RNN (GRU) to session-based recommendation; compared with traditional KNN and matrix-factorization approaches, it improves results markedly.

The core idea is to treat the sequence of items a user clicks within one session as a training sequence for the RNN. At prediction time, given a known click sequence as input, the model predicts the item most likely to be clicked next.

Session-based recommendation applies to a wide range of sequential data, such as product browsing, news clicks, and location check-ins.

By default this model is configured for the demo dataset; to verify accuracy, see the [Paper reproduction](#paper-reproduction) section.

Supported functionality:

Training: single-machine CPU, single-machine single-GPU, locally simulated parameter-server training, and incremental training; for configuration see [Launch training](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/train.md)

Prediction: single-machine CPU and single-machine single-GPU; for configuration see [PaddleRec offline prediction](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/predict.md)
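The idea can be sketched in a few lines: embed each clicked item, feed the embeddings through a GRU cell, and score every item in the vocabulary against the final hidden state. The sketch below is a toy illustration with random weights, not the Paddle model; `V` and `H` loosely mirror `vocab_size` and `hid_size` from config.yaml.

```python
import numpy as np

rng = np.random.default_rng(0)
V, H = 1000, 100  # toy vocabulary size and hidden size

emb = rng.normal(0, 0.1, (V, H))  # item embedding table
Wz, Uz = rng.normal(0, 0.1, (H, H)), rng.normal(0, 0.1, (H, H))
Wr, Ur = rng.normal(0, 0.1, (H, H)), rng.normal(0, 0.1, (H, H))
Wc, Uc = rng.normal(0, 0.1, (H, H)), rng.normal(0, 0.1, (H, H))
out = rng.normal(0, 0.1, (H, V))  # output projection over all items


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(h, x):
    z = sigmoid(x @ Wz + h @ Uz)        # update gate
    r = sigmoid(x @ Wr + h @ Ur)        # reset gate
    c = np.tanh(x @ Wc + (r * h) @ Uc)  # candidate state
    return (1 - z) * h + z * c


def next_item_scores(click_seq):
    # Run the session's click sequence through the GRU, then score all items.
    h = np.zeros(H)
    for item in click_seq:
        h = gru_step(h, emb[item])
    return h @ out  # one score per item; the top 20 are the Recall@20 candidates


top20 = np.argsort(-next_item_scores([3, 17, 42]))[:20]
```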
## Data processing
Data processing in this example takes three steps:
- Step 1: download the raw dataset.
```
cd data/
python download.py
```
- Step 2: preprocess the data and convert its format.
  1. Merge the raw dataset by session_id to obtain each session's date and ordered click list.
  2. Filter out sessions of length 1 and items clicked fewer than 5 times.
  3. Split into training and test sets: sessions falling within the latest seven days of the dataset form the training set, and earlier data forms the test set.
```
python preprocess.py
python convert_format.py
```
After this step, two files appear under data/: rsc15_train_tr_paddle.txt, the raw training file, and rsc15_test_paddle.txt, the raw test file. Their format looks like this:
```
214536502 214536500 214536506 214577561
214662742 214662742 214825110 214757390 214757407 214551617
214716935 214774687 214832672
214836765 214706482
214701242 214826623
214826835 214826715
214838855 214838855
214576500 214576500 214576500
214821275 214821275 214821371 214821371 214821371 214717089 214563337 214706462 214717436 214743335 214826837 214819762
214717867 21471786
```
- Step 3: generate the dictionary and arrange the data paths. This step builds a dictionary from the training and test files, generates the corresponding Paddle input files, and gathers the training files under data/all_train and the test files under data/all_test.
```
mkdir raw_train_data && mkdir raw_test_data
mv rsc15_train_tr_paddle.txt raw_train_data/ && mv rsc15_test_paddle.txt raw_test_data/
mkdir all_train && mkdir all_test

python text2paddle.py raw_train_data/ raw_test_data/ all_train all_test vocab.txt
```
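The dictionary generation in Step 3 amounts to assigning each distinct item id a dense integer index in first-seen order; a minimal sketch of that idea (the actual text2paddle.py also rewrites the train/test files using the new ids):

```python
def build_vocab(lines):
    # Map every distinct token to a dense integer id, in first-seen order.
    vocab = {}
    for line in lines:
        for token in line.split():
            vocab.setdefault(token, len(vocab))
    return vocab


vocab = build_vocab(["214536502 214536500", "214536500 214577561"])
print(vocab)  # → {'214536502': 0, '214536500': 1, '214577561': 2}
```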
For convenience, a one-step data preparation script is also provided:
```
sh data_prepare.sh
```
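The Step 2 filtering rules can be sketched in a few lines. This is a simplified in-memory illustration, not the actual preprocess.py, which also handles dates and the train/test split:

```python
from collections import Counter


def filter_sessions(sessions, min_item_clicks=5, min_session_len=2):
    """sessions: dict of session_id -> ordered list of clicked item ids."""
    # Drop items clicked fewer than min_item_clicks times overall.
    clicks = Counter(item for seq in sessions.values() for item in seq)
    kept = {sid: [i for i in seq if clicks[i] >= min_item_clicks]
            for sid, seq in sessions.items()}
    # Drop sessions that are (or became) shorter than min_session_len.
    return {sid: seq for sid, seq in kept.items() if len(seq) >= min_session_len}


sessions = {1: ["a", "b"], 2: ["a", "a"], 3: ["a"], 4: ["a", "c"]}
print(filter_sessions(sessions))  # → {2: ['a', 'a']}
```

Item "a" is clicked 5 times in total so it survives, while "b" and "c" are dropped; only session 2 keeps at least two clicks afterward.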
## Environment

PaddlePaddle >= 1.7.2

python 2.7/3.5/3.6/3.7

PaddleRec >= 0.1

OS: Windows/Linux/macOS

## Quick start

### Single-machine training

Set the device, number of epochs, and so on in config.yaml:
```
runner:
- name: cpu_train_runner
  class: train
  device: cpu # gpu
  epochs: 10
  save_checkpoint_interval: 1
  save_inference_interval: 1
  save_checkpoint_path: "increment_gru4rec"
  save_inference_path: "inference_gru4rec"
  save_inference_feed_varnames: ["src_wordseq", "dst_wordseq"] # feed vars of save inference
  save_inference_fetch_varnames: ["mean_0.tmp_0", "top_k_0.tmp_0"]
  print_interval: 10
  phases: [train]
```
### Single-machine prediction

Set the device, the model path to load, and so on in config.yaml:
```
- name: cpu_infer_runner
  class: infer
  init_model_path: "increment_gru4rec"
  device: cpu # gpu
  phases: [infer]
```

### Run
```
python -m paddlerec.run -m paddlerec.models.recall.gru4rec
```
### Results

Sample training output:

```
Running SingleStartup.
Running SingleRunner.
2020-09-22 03:31:18,167-INFO: [Train], epoch: 0, batch: 10, time_each_interval: 4.34s, RecallCnt: [1669.], cost: [8.366313], InsCnt: [16228.], Acc(Recall@20): [0.10284693]
2020-09-22 03:31:21,982-INFO: [Train], epoch: 0, batch: 20, time_each_interval: 3.82s, RecallCnt: [3168.], cost: [8.170701], InsCnt: [31943.], Acc(Recall@20): [0.09917666]
2020-09-22 03:31:25,797-INFO: [Train], epoch: 0, batch: 30, time_each_interval: 3.81s, RecallCnt: [4855.], cost: [8.017181], InsCnt: [47892.], Acc(Recall@20): [0.10137393]
...
epoch 0 done, use time: 6003.78719687, global metrics: cost=[4.4394927], InsCnt=23622448.0 RecallCnt=14547467.0 Acc(Recall@20)=0.6158323218660487
2020-09-22 05:11:17,761-INFO: save epoch_id:0 model into: "inference_gru4rec/0"
...
epoch 9 done, use time: 6009.97707605, global metrics: cost=[4.069373], InsCnt=236237470.0 RecallCnt=162838200.0 Acc(Recall@20)=0.6892988086157644
2020-09-22 20:17:11,358-INFO: save epoch_id:9 model into: "inference_gru4rec/9"
PaddleRec Finish
```

Sample prediction output:
```
Running SingleInferStartup.
Running SingleInferRunner.
load persistables from increment_gru4rec/9
2020-09-23 03:46:21,081-INFO: [Infer] batch: 20, time_each_interval: 3.68s, RecallCnt: [24875.], InsCnt: [35581.], Acc(Recall@20): [0.6991091]
Infer infer of epoch 9 done, use time: 5.25408315659, global metrics: InsCnt=52551.0 RecallCnt=36720.0 Acc(Recall@20)=0.698749785922247
...
Infer infer of epoch 0 done, use time: 5.20699501038, global metrics: InsCnt=52551.0 RecallCnt=33664.0 Acc(Recall@20)=0.6405967536298073
PaddleRec Finish
```
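In these logs, Acc(Recall@20) is simply the running ratio RecallCnt / InsCnt; for example, the final epoch-9 infer metric can be reproduced directly:

```python
ins_cnt, recall_cnt = 52551.0, 36720.0  # from the epoch-9 infer line in the log
recall_at_20 = recall_cnt / ins_cnt     # ≈ 0.69875, the Acc(Recall@20) reported
```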
## Paper reproduction

Reproducing the paper's results on the full dataset requires changing the hyperparameters in config.yaml:
- batch_size: set batch_size of the dataset_train dataset to 500.
- epochs: set the runner's epochs to 10.
- data source: set data_path of the dataset_train dataset to "{workspace}/data/all_train" and data_path of the dataset_infer dataset to "{workspace}/data/all_test".

Training for 10 epochs on GPU gives the following test results:

epoch | test Recall@20 | time (s)
-- | -- | --
1 | 0.6406 | 6003
2 | 0.6727 | 6007
3 | 0.6831 | 6108
4 | 0.6885 | 6025
5 | 0.6913 | 6019
6 | 0.6931 | 6011
7 | 0.6952 | 6015
8 | 0.6968 | 6076
9 | 0.6972 | 6076
10 | 0.6987 | 6009

After these changes, set 'workspace' in config.yaml to the directory containing config.yaml and run:
```
python -m paddlerec.run -m /home/your/dir/config.yaml # debug mode: point directly at the absolute path of the local config
```

## Advanced usage

## FAQ

models/recall/gru4rec/config.yaml

Lines changed: 23 additions & 18 deletions
```diff
@@ -16,18 +16,19 @@ workspace: "models/recall/gru4rec"
 
 dataset:
 - name: dataset_train
-  batch_size: 5
-  type: QueueDataset
+  batch_size: 500
+  type: DataLoader # QueueDataset
   data_path: "{workspace}/data/train"
   data_converter: "{workspace}/rsc15_reader.py"
 - name: dataset_infer
-  batch_size: 5
-  type: QueueDataset
+  batch_size: 500
+  type: DataLoader #QueueDataset
   data_path: "{workspace}/data/test"
   data_converter: "{workspace}/rsc15_reader.py"
 
 hyper_parameters:
-  vocab_size: 1000
+  recall_k: 20
+  vocab_size: 37483
   hid_size: 100
   emb_lr_x: 10.0
   gru_lr_x: 1.0
@@ -40,30 +41,34 @@ hyper_parameters:
   strategy: async
 
 #use infer_runner mode and modify 'phase' below if infer
-mode: train_runner
+mode: [cpu_train_runner, cpu_infer_runner]
 #mode: infer_runner
 
 runner:
-- name: train_runner
+- name: cpu_train_runner
   class: train
   device: cpu
-  epochs: 3
-  save_checkpoint_interval: 2
-  save_inference_interval: 4
-  save_checkpoint_path: "increment"
-  save_inference_path: "inference"
+  epochs: 10
+  save_checkpoint_interval: 1
+  save_inference_interval: 1
+  save_checkpoint_path: "increment_gru4rec"
+  save_inference_path: "inference_gru4rec"
+  save_inference_feed_varnames: ["src_wordseq", "dst_wordseq"] # feed vars of save inference
+  save_inference_fetch_varnames: ["mean_0.tmp_0", "top_k_0.tmp_0"]
   print_interval: 10
-- name: infer_runner
+  phases: [train]
+- name: cpu_infer_runner
   class: infer
-  init_model_path: "increment/0"
+  init_model_path: "increment_gru4rec"
   device: cpu
+  phases: [infer]
 
 phase:
 - name: train
   model: "{workspace}/model.py"
   dataset_name: dataset_train
   thread_num: 1
-#- name: infer
-#  model: "{workspace}/model.py"
-#  dataset_name: dataset_infer
-#  thread_num: 1
+- name: infer
+  model: "{workspace}/model.py"
+  dataset_name: dataset_infer
+  thread_num: 1
```
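With this change, `mode` is a list of runner names, and each runner's `phases` entries must be declared under `phase:`. A small standalone consistency check of that invariant (an illustration only, not PaddleRec's own config loader):

```python
def check_config(cfg):
    # Every entry in 'mode' must name a runner; every phase a runner
    # references must be declared under 'phase'.
    runners = {r["name"]: r for r in cfg["runner"]}
    phase_names = {p["name"] for p in cfg["phase"]}
    for mode in cfg["mode"]:
        assert mode in runners, "mode %r has no matching runner" % mode
        for ph in runners[mode].get("phases", []):
            assert ph in phase_names, "unknown phase %r in runner %r" % (ph, mode)


check_config({
    "mode": ["cpu_train_runner", "cpu_infer_runner"],
    "runner": [
        {"name": "cpu_train_runner", "phases": ["train"]},
        {"name": "cpu_infer_runner", "phases": ["infer"]},
    ],
    "phase": [{"name": "train"}, {"name": "infer"}],
})
```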
models/recall/gru4rec/data/convert_format.py

Lines changed: 48 additions & 0 deletions

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import codecs


def convert_format(input, output):
    # Collapse "session_id item_id" rows into one space-separated line of
    # item ids per session. Rows for the same session are assumed to be
    # consecutive; the first input line is a header and is skipped.
    with codecs.open(input, "r", encoding='utf-8') as rf:
        with codecs.open(output, "w", encoding='utf-8') as wf:
            last_sess = -1
            sign = 1  # still on the first session: no leading newline needed
            i = 0
            for l in rf:
                i = i + 1
                if i == 1:
                    continue
                if (i % 1000000 == 1):
                    print(i)
                tokens = l.strip().split()
                if (int(tokens[0]) != last_sess):
                    if (sign):
                        sign = 0
                        wf.write(tokens[1] + " ")
                    else:
                        wf.write("\n" + tokens[1] + " ")
                    last_sess = int(tokens[0])
                else:
                    wf.write(tokens[1] + " ")


input = "rsc15_train_tr.txt"
output = "rsc15_train_tr_paddle.txt"
input2 = "rsc15_test.txt"
output2 = "rsc15_test_paddle.txt"
convert_format(input, output)
convert_format(input2, output2)
```
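The grouping done by convert_format can be illustrated on in-memory rows. This is a condensed re-implementation of the same consecutive-session grouping, assuming the rsc15 `session_id item_id` row format with the header already removed:

```python
def to_session_lines(rows):
    # rows: iterable of "session_id item_id" strings, same-session rows consecutive.
    sessions, last = [], None
    for row in rows:
        sess, item = row.split()
        if sess != last:       # a new session starts a new line
            sessions.append([])
            last = sess
        sessions[-1].append(item)
    return [" ".join(s) for s in sessions]


print(to_session_lines(["1 214536502", "1 214536500", "2 214716935"]))
# → ['214536502 214536500', '214716935']
```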
models/recall/gru4rec/data/download.py

Lines changed: 61 additions & 0 deletions

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import requests
import shutil  # used by the no-Content-Length fallback below
import sys
import time
import os

lasttime = time.time()
FLUSH_INTERVAL = 0.1


def progress(str, end=False):
    # Redraw a single console line, at most once every FLUSH_INTERVAL seconds.
    global lasttime
    if end:
        str += "\n"
        lasttime = 0
    if time.time() - lasttime >= FLUSH_INTERVAL:
        sys.stdout.write("\r%s" % str)
        lasttime = time.time()
        sys.stdout.flush()


def _download_file(url, savepath, print_progress):
    r = requests.get(url, stream=True)
    total_length = r.headers.get('content-length')

    if total_length is None:
        # No Content-Length header: stream the body without a progress bar.
        with open(savepath, 'wb') as f:
            shutil.copyfileobj(r.raw, f)
    else:
        with open(savepath, 'wb') as f:
            dl = 0
            total_length = int(total_length)
            starttime = time.time()
            if print_progress:
                print("Downloading %s" % os.path.basename(savepath))
            for data in r.iter_content(chunk_size=4096):
                dl += len(data)
                f.write(data)
                if print_progress:
                    done = int(50 * dl / total_length)
                    progress("[%-50s] %.2f%%" %
                             ('=' * done, float(100 * dl) / total_length))
        if print_progress:
            progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)


_download_file("https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat",
               "./yoochoose-clicks.dat", True)
```
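The progress string written in the download loop boils down to a fixed 50-cell bar plus a percentage; the same arithmetic in isolation:

```python
def progress_bar(done_bytes, total_bytes, width=50):
    # Same formatting as the download loop: filled '=' cells, left-padded
    # to the full bar width, followed by the percentage complete.
    filled = int(width * done_bytes / total_bytes)
    return "[%-*s] %.2f%%" % (width, "=" * filled,
                              100.0 * done_bytes / total_bytes)


print(progress_bar(2048, 4096))  # half done: 25 of 50 cells filled, 50.00%
```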
