Commit 7314eab

add mind_model

1 parent 4a6006e commit 7314eab

File tree

14 files changed

+23307
-0
lines changed


doc/imgs/mind.png

91.5 KB

models/recall/mind/README.md

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# MIND (Multi-Interest Network with Dynamic Routing)

Below is a brief directory structure for this example, with descriptions:

```shell
├── data                    # sample data
│   ├── demo                # demo training data
│   │   └── demo.txt
│   ├── preprocess.py       # script that processes the full dataset
│   ├── run.sh              # script that downloads the full dataset
│   └── valid               # demo test data
│       └── part-0
├── config.yaml             # data configuration
├── dygraph_model.py        # builds the dynamic graph
├── evaluate_dygraph.py     # evaluates the dynamic graph
├── evaluate_reader.py      # data reader for evaluation
├── evaluate_static.py      # evaluates the static graph
├── mind_reader.py          # data reader for training
├── net.py                  # core model network (shared by the dynamic and static graphs)
└── static_model.py         # builds the static graph
```

Note: before reading this example, it is recommended that you first go through the following:

[PaddleRec getting-started tutorial](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md)

## Contents

- [Model Overview](#model-overview)
- [Data Preparation](#data-preparation)
- [Runtime Environment](#runtime-environment)
- [Quick Start](#quick-start)
- [Model Network](#model-network)
- [Reproducing the Results](#reproducing-the-results)
- [Advanced Usage](#advanced-usage)
- [FAQ](#faq)

## Model Overview

This example implements a multi-interest user network based on dynamic routing, as illustrated below:
<p align="center">
<img align="center" src="../../../doc/imgs/mind.png">
</p>
Multi-Interest Network with Dynamic Routing (MIND) builds multiple user interest vectors in the same vector space as the item vectors, so as to capture the diversity of a user's interests. Vector-based retrieval then uses these interest vectors to look up the top-K nearest item vectors, yielding the top-K items the user is most likely interested in. Its core component is a Multi-Interest Extractor Layer built on capsule networks and B2I dynamic routing.

Recommended reference paper: [http://cn.arxiv.org/abs/1904.08030](http://cn.arxiv.org/abs/1904.08030)
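
The B2I dynamic routing at the heart of the Multi-Interest Extractor Layer can be sketched in plain numpy. This is an illustrative simplification, not the commit's implementation (that lives in `net.py`); the random logit initialization, the shared bilinear map `S`, and the softmax axis are assumptions made for the sketch:

```python
import numpy as np

def squash(z, eps=1e-9):
    # squash(z) = (||z||^2 / (1 + ||z||^2)) * z / ||z||
    n2 = np.sum(z * z, axis=-1, keepdims=True)
    return (n2 / (1.0 + n2)) * z / np.sqrt(n2 + eps)

def b2i_routing(behaviors, num_interest=4, iters=3, seed=0):
    # behaviors: (n, d) embeddings of one user's clicked items.
    rng = np.random.default_rng(seed)
    n, d = behaviors.shape
    S = rng.normal(scale=0.1, size=(d, d))        # shared bilinear mapping
    mapped = behaviors @ S                        # (n, d)
    logits = rng.normal(size=(num_interest, n))   # routing logits, fixed random init
    for _ in range(iters):
        # Each behavior distributes its vote over the interest capsules.
        w = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)
        interests = squash(w @ mapped)            # (num_interest, d)
        logits = logits + interests @ mapped.T    # agreement update
    return interests

interests = b2i_routing(np.random.default_rng(1).normal(size=(8, 16)))
print(interests.shape)  # (4, 16); each interest vector has norm < 1 after squash
```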

## Data Preparation

Sample data for a quick run is provided in the data directory under the model directory; the training data and test data are stored in the data/train and data/valid folders respectively. To use the full dataset, refer to the reproduction section below.

The training data format is as follows:
```
0,17978,0
0,901,1
0,97224,2
0,774,3
0,85757,4
```
The three fields are the uid, the item_id, and the click order (timestamp).
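
For reference, these triples can be grouped into per-user, time-ordered click sequences with a few lines of Python; this mirrors the `load_graph` helper shipped in this commit's preprocessing script:

```python
from collections import defaultdict

def load_clicks(lines):
    # Each line is "uid,item_id,timestamp"; returns {uid: [item_ids in click order]}.
    graph = defaultdict(list)
    for line in lines:
        uid, item, ts = (int(x) for x in line.strip().split(","))
        graph[uid].append((item, ts))
    return {u: [i for i, _ in sorted(v, key=lambda p: p[1])]
            for u, v in graph.items()}

print(load_clicks(["0,17978,0", "0,901,1", "0,97224,2"]))
# {0: [17978, 901, 97224]}
```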

The test data format is as follows:
```
user_id:543354 hist_item:143963 hist_item:157508 hist_item:105486 hist_item:40502 hist_item:167813 hist_item:233564 hist_item:221866 hist_item:280310 hist_item:61638 hist_item:158494 hist_item:74449 hist_item:283630 hist_item:135155 hist_item:96176 hist_item:20139 hist_item:89420 hist_item:247990 hist_item:126605 target_item:172183 target_item:114193 target_item:79966 target_item:134420 target_item:50557
user_id:543362 hist_item:119546 hist_item:78597 hist_item:86809 hist_item:63551 target_item:326165
user_id:543366 hist_item:45463 hist_item:9903 hist_item:3956 hist_item:49726 target_item:199426
```
Here `hist_item` and `target_item` are both variable-length sequences; see `evaluate_reader.py` for how they are read.
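
Each such line is a sequence of `key:value` tokens in which `hist_item` and `target_item` repeat. A minimal standalone parser (illustrative only; the actual batching and padding logic is in `evaluate_reader.py`) could look like:

```python
def parse_eval_line(line):
    # Collect the repeated key:value tokens into lists; user_id is a scalar.
    sample = {"user_id": None, "hist_item": [], "target_item": []}
    for token in line.split():
        key, value = token.split(":")
        if key == "user_id":
            sample["user_id"] = int(value)
        elif key in sample:
            sample[key].append(int(value))
    return sample

line = "user_id:543362 hist_item:119546 hist_item:78597 target_item:326165"
print(parse_eval_line(line))
# {'user_id': 543362, 'hist_item': [119546, 78597], 'target_item': [326165]}
```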

## Runtime Environment

PaddlePaddle >= 2.0

python 2.7/3.5/3.6/3.7

os: windows/linux/macos

## Quick Start

The quick-start commands in the mind model directory are as follows:
```bash
# enter the model directory
# cd models/recall/mind # can be run from any directory
# dynamic-graph training
python -u ../../../tools/trainer.py -m config.yaml
# dynamic-graph inference
python -u evaluate_dygraph.py -m config.yaml -top_n 50 # predicts on the test data, retrieves candidates with faiss, and reports Recall, NDCG and HitRate

# static-graph training
python -u ../../../tools/static_trainer.py -m config.yaml # use config_bigdata.yaml for the full dataset
# static-graph inference
python -u evaluate_static.py -m config.yaml -top_n 50 # predicts on the test data, retrieves candidates with faiss, and reports Recall, NDCG and HitRate
```
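
The metrics reported by the evaluation scripts can be computed from the retrieved top-N lists as in the sketch below. The function name is hypothetical and the real scripts use faiss for the nearest-neighbor search; only the metric step is shown here:

```python
def recall_metrics(retrieved, targets):
    # retrieved: per-user top-N item lists; targets: per-user ground-truth lists.
    recalls, hits = [], 0
    for topn, gt in zip(retrieved, targets):
        overlap = len(set(topn) & set(gt))
        recalls.append(overlap / len(gt))   # Recall@N for this user
        hits += 1 if overlap > 0 else 0     # HitRate counts users with any hit
    return sum(recalls) / len(recalls), hits / len(targets)

recall, hitrate = recall_metrics([[1, 2, 3], [4, 5]], [[1, 9], [6]])
print(recall, hitrate)  # 0.25 0.5
```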

## Model Network

See the [Model Overview](#model-overview) section above for details.

## Reproducing the Results

Since the original paper does not provide the details needed to reproduce its experiments, and to let users quickly run every model, we reproduce the training task on the AmazonBook dataset provided by the [ComiRec](https://arxiv.org/abs/2005.09347) paper. Sample data is provided under every model. To reproduce the results in this readme, follow the steps below in order.

The metrics on the full dataset are as follows:

| Model | batch_size | epoch_num | Recall@50 | NDCG@50 | HitRate@50 | Time per epoch |
| :------ | :------ | :------ | :------ | :------ | :------ | :------ |
| mind (static graph) | 128 | 6 | 4.61% | 11.28% | 18.92% | -- |
| mind (dynamic graph) | 128 | 6 | 4.57% | 11.25% | 18.99% | -- |

1. Make sure your current directory is PaddleRec/models/recall/mind
2. Enter the data directory and run the run.sh script, which downloads the preprocessed AmazonBook dataset and extracts it into the expected directories
```bash
cd ./data
sh run.sh
```
3. Switch back to the model directory and run on the full dataset
```bash
cd - # back to the model directory
# dynamic-graph training
python -u ../../../tools/trainer.py -m config.yaml # run config.yaml on the full dataset
python -u evaluate_dygraph.py -m config.yaml # run config.yaml on the full dataset
```

## Advanced Usage

## FAQ

## References

The dataset and training task follow [ComiRec](https://github.com/THUDM/ComiRec)

models/recall/mind/config.yaml

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

runner:
  train_data_dir: "data/train"
  train_reader_path: "mind_reader"  # importlib format
  use_gpu: True
  use_auc: False
  train_batch_size: 128
  epochs: 6
  print_interval: 500
  model_save_path: "output_model_mind"
  infer_batch_size: 128
  infer_reader_path: "evaluate_reader"  # importlib format
  test_data_dir: "data/valid"
  infer_load_path: "output_model_mind"
  infer_start_epoch: 0
  infer_end_epoch: 4

  # distribute_config
  # sync_mode: "async"
  # split_file_list: False
  # thread_num: 1


# hyper parameters of user-defined network
hyper_parameters:
  # optimizer config
  optimizer:
    class: Adam
    learning_rate: 0.005
    # strategy: async
  # user-defined <key, value> pairs
  item_count: 367983
  embedding_dim: 64
  hidden_size: 64
  neg_samples: 1280
  maxlen: 20
  pow_p: 1.0
  distributed_embedding: 0
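
The `hyper_parameters` block above is passed through to the network definition in `net.py`. Assuming PyYAML is available (an assumption; it is a common PaddleRec dependency), the block can be inspected like this:

```python
import yaml  # PyYAML, assumed available

config_text = """
hyper_parameters:
  optimizer:
    class: Adam
    learning_rate: 0.005
  item_count: 367983
  embedding_dim: 64
  neg_samples: 1280
"""
hp = yaml.safe_load(config_text)["hyper_parameters"]
print(hp["embedding_dim"], hp["neg_samples"])  # 64 1280
```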
models/recall/mind/data/preprocess.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function  # needed for print(..., file=...) on python 2.7

import argparse
import random

parser = argparse.ArgumentParser()
parser.add_argument(
    "-type", type=str, default="train", help="train|valid|test")
parser.add_argument("-maxlen", type=int, default=20)


def load_graph(source):
    # Group (item_id, time_stamp) pairs by user, then sort each user's
    # clicks by timestamp so the list reflects the browsing order.
    graph = {}
    with open(source) as fr:
        for line in fr:
            conts = line.strip().split(',')
            user_id = int(conts[0])
            item_id = int(conts[1])
            time_stamp = int(conts[2])
            if user_id not in graph:
                graph[user_id] = []
            graph[user_id].append((item_id, time_stamp))

    for user_id, value in graph.items():
        value.sort(key=lambda x: x[1])
        graph[user_id] = [x[0] for x in value]
    return graph


if __name__ == "__main__":
    args = parser.parse_args()
    # Shard the output across 10 part files.
    filelist = []
    for i in range(10):
        filelist.append(open(args.type + "/part-%d" % (i), "w"))
    action_graph = load_graph("data/book_data/book_" + args.type + ".txt")
    if args.type == "train":
        # Sliding window: every click after the 4th becomes a target,
        # with up to maxlen preceding clicks as the history.
        for uid, item_list in action_graph.items():
            for i in range(4, len(item_list)):
                if i >= args.maxlen:
                    hist_item = item_list[i - args.maxlen:i]
                else:
                    hist_item = item_list[:i]
                target_item = item_list[i]
                print(
                    " ".join(["user_id:" + str(uid)] + [
                        "hist_item:" + str(n) for n in hist_item
                    ] + ["target_item:" + str(target_item)]),
                    file=random.choice(filelist))
    else:
        # Hold-out split: the first 80% of each user's clicks form the
        # history, the remaining 20% become the evaluation items.
        for uid, item_list in action_graph.items():
            k = int(len(item_list) * 0.8)
            if k >= args.maxlen:
                hist_item = item_list[k - args.maxlen:k]
            else:
                hist_item = item_list[:k]
            target_item = item_list[k:]
            print(
                " ".join(["user_id:" + str(uid), "target_item:0"] + [
                    "hist_item:" + str(n) for n in hist_item
                ] + ["eval_item:" + str(n) for n in target_item]),
                file=random.choice(filelist))
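
The training-sample windowing above can be exercised in isolation. The helper name below is hypothetical, but the window logic matches the loop in the script (history capped at `maxlen`, first target at index 4):

```python
def train_windows(item_list, maxlen=20):
    # One (history, target) sample per click from index 4 onward.
    samples = []
    for i in range(4, len(item_list)):
        hist = item_list[i - maxlen:i] if i >= maxlen else item_list[:i]
        samples.append((hist, item_list[i]))
    return samples

print(train_windows([10, 11, 12, 13, 14, 15], maxlen=3))
# [([11, 12, 13], 14), ([12, 13, 14], 15)]
```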

models/recall/mind/data/run.sh

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
wget "https://cloud.tsinghua.edu.cn/f/e5c4211255bc40cba828/?dl=1" -O data.tar.gz

tar -xvf data.tar.gz

rm -rf train valid
mkdir train
mkdir valid

mv data/book_data/book_train.txt train
python preprocess.py -type valid -maxlen 20
rm -rf data.tar.gz
rm -rf data
