Skip to content

Commit 3046223

Browse files
committed
Conflict resolution
2 parents ca15393 + 623e14d commit 3046223

File tree

19 files changed

+21858
-11
lines changed

19 files changed

+21858
-11
lines changed

datasets/AmazonBook/preprocess.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import sys
15+
import os
16+
import json
17+
import numpy as np
18+
import argparse
19+
import random
20+
21+
parser = argparse.ArgumentParser()
22+
parser.add_argument(
23+
"-type", type=str, default="train", help="train|valid|test")
24+
parser.add_argument("-maxlen", type=int, default=20)
25+
26+
27+
def load_graph(source):
28+
graph = {}
29+
with open(source) as fr:
30+
for line in fr:
31+
conts = line.strip().split(',')
32+
user_id = int(conts[0])
33+
item_id = int(conts[1])
34+
time_stamp = int(conts[2])
35+
if user_id not in graph:
36+
graph[user_id] = []
37+
graph[user_id].append((item_id, time_stamp))
38+
39+
for user_id, value in graph.items():
40+
value.sort(key=lambda x: x[1])
41+
graph[user_id] = [x[0] for x in value]
42+
return graph
43+
44+
45+
if __name__ == "__main__":
46+
args = parser.parse_args()
47+
filelist = []
48+
for i in range(10):
49+
filelist.append(open(args.type + "/part-%d" % (i), "w"))
50+
action_graph = load_graph("book_data/book_" + args.type + ".txt")
51+
if args.type == "train":
52+
for uid, item_list in action_graph.items():
53+
for i in range(4, len(item_list)):
54+
if i >= args.maxlen:
55+
hist_item = item_list[i - args.maxlen:i]
56+
else:
57+
hist_item = item_list[:i]
58+
target_item = item_list[i]
59+
print(
60+
" ".join(["user_id:" + str(uid)] + [
61+
"hist_item:" + str(n) for n in hist_item
62+
] + ["target_item:" + str(target_item)]),
63+
file=random.choice(filelist))
64+
else:
65+
for uid, item_list in action_graph.items():
66+
k = int(len(item_list) * 0.8)
67+
if k >= args.maxlen:
68+
hist_item = item_list[k - args.maxlen:k]
69+
else:
70+
hist_item = item_list[:k]
71+
target_item = item_list[k:]
72+
print(
73+
" ".join(["user_id:" + str(uid), "target_item:0"] + [
74+
"hist_item:" + str(n) for n in hist_item
75+
] + ["eval_item:" + str(n) for n in target_item]),
76+
file=random.choice(filelist))

datasets/AmazonBook/run.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
wget https://paddlerec.bj.bcebos.com/datasets/AmazonBook/AmazonBook.tar.gz
3+
4+
tar -xvf AmazonBook.tar.gz
5+
6+
rm -rf train valid
7+
mkdir train
8+
mkdir valid
9+
10+
mv book_data/book_train.txt train
11+
python preprocess.py -type valid -maxlen 20

datasets/readme.md

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@ sh data_process.sh
2828
|[movielens_pinterest_NCF](https://paddlerec.bj.bcebos.com/ncf/Data.zip)|论文原作者处理过的movielens数据集和pinterest数据集|[《Neural Collaborative Filtering 》](https://arxiv.org/pdf/1708.05031.pdf)|
2929
|[Anime](https://paddlerec.bj.bcebos.com/datasets/Anime/archive.zip)|该数据集包含73,516个用户对12,294个动漫的用户偏好数据。每个用户都可以将动漫添加到列表中并给它一个评分,该数据集是这些评分的汇总。|[Kaggle](https://www.kaggle.com/CooperUnion/anime-recommendations-database)|
3030
|[LFM-1b](https://paddlerec.bj.bcebos.com/datasets/LFM_1b/LFM-1b.zip)|此数据集包含由Last.FM的120,000多个用户创建的十亿多个音乐收听记录。每条收听记录均以艺术家,专辑和曲目名称为特征,并包含一个时间戳。|[ICMR 2016](http://www.cp.jku.at/datasets/LFM-1b/)|
31-
|[LFM-1b UGP](https://paddlerec.bj.bcebos.com/datasets/LFM_1b_UGP/LFM-1b_UGP.zip)|LFM-1b数据集的用户类型档案,作为LFM-1b的补充扩展|[ISM 2017](http://www.cp.jku.at/datasets/LFM-1b/)|
32-
|[Jester](https://paddlerec.bj.bcebos.com/datasets/Jester/JesterDataset3.zip)|此数据集包含Jester Joke Recommender系统用户对笑话的匿名评分。|[UC Berkeley](http://eigentaste.berkeley.edu/dataset/)|
33-
|[Steam](https://paddlerec.bj.bcebos.com/datasets/steam/steam_reviews.json.gz)|该数据集是Steam的评论和游戏信息,其中包含7,793,069条评论,2,567,538位用户和32,135个游戏。除评论文本外,数据还包括每个评论中用户的游戏时间。|[ICDM 2018](https://github.com/kang205/SASRec)|
34-
|[Douban](https://paddlerec.bj.bcebos.com/datasets/Douban/DMSC.csv)|豆瓣电影是一个中文网站,允许互联网用户分享他们对电影的评论和观点。用户可以在电影上发表简短或长时间的评论并给他们打分。该数据集包含“豆瓣电影”网站中28部电影的200万条简短评论。|[Kaggle](https://www.kaggle.com/utmhikari/doubanmovieshortcomments)|
35-
|[TaFeng](https://paddlerec.bj.bcebos.com/datasets/tafeng/ta_feng_all_months_merged.csv)|该数据集包含2000年11月至2001年2月中国杂货店的交易数据。|[Kaggle](https://www.kaggle.com/chiranjivdas09/ta-feng-grocery-dataset)|
36-
|[Retailrocket](https://paddlerec.bj.bcebos.com/datasets/Retailrocket/Retailrocket.zip)|数据是从真实的电子商务网站中收集的。它是原始数据,即没有任何内容转换,但是,由于保密问题,所有值都被哈希化。|[Kaggle](https://www.kaggle.com/retailrocket/ecommerce-dataset)|
37-
|[Netflix](https://paddlerec.bj.bcebos.com/datasets/Netflix/Netflix.zip)|这是Netflix竞赛中使用的官方数据集。|[Kaggle](https://www.kaggle.com/netflix-inc/netflix-prize-data)|
38-
|[FourSquare](https://paddlerec.bj.bcebos.com/datasets/FourSquare/FourSquare.zip)|此数据集包含在纽约和东京进行的大约10个月收集的签到。每个签到都有其时间戳,GPS坐标及其语义相关联。|[Kaggle](https://www.kaggle.com/chetanism/foursquare-nyc-and-tokyo-checkin-dataset)|
31+
|[LFM-1b UGP](https://paddlerec.bj.bcebos.com/datasets/LFM_1b_UGP/LFM-1b_UGP.zip)|LFM-1b数据集的用户类型档案,作为LFM-1b的补充扩展|[ISM 2017](http://www.cp.jku.at/datasets/LFM-1b/)|
32+
|[Jester](https://paddlerec.bj.bcebos.com/datasets/Jester/JesterDataset3.zip)|此数据集包含Jester Joke Recommender系统用户对笑话的匿名评分。|[UC Berkeley](http://eigentaste.berkeley.edu/dataset/)|
33+
|[Steam](https://paddlerec.bj.bcebos.com/datasets/steam/steam_reviews.json.gz)|该数据集是Steam的评论和游戏信息,其中包含7,793,069条评论,2,567,538位用户和32,135个游戏。除评论文本外,数据还包括每个评论中用户的游戏时间。|[ICDM 2018](https://github.com/kang205/SASRec)|
34+
|[Douban](https://paddlerec.bj.bcebos.com/datasets/Douban/DMSC.csv)|豆瓣电影是一个中文网站,允许互联网用户分享他们对电影的评论和观点。用户可以在电影上发表简短或长时间的评论并给他们打分。该数据集包含“豆瓣电影”网站中28部电影的200万条简短评论。|[Kaggle](https://www.kaggle.com/utmhikari/doubanmovieshortcomments)|
35+
|[TaFeng](https://paddlerec.bj.bcebos.com/datasets/tafeng/ta_feng_all_months_merged.csv)|该数据集包含2000年11月至2001年2月中国杂货店的交易数据。|[Kaggle](https://www.kaggle.com/chiranjivdas09/ta-feng-grocery-dataset)|
36+
|[Retailrocket](https://paddlerec.bj.bcebos.com/datasets/Retailrocket/Retailrocket.zip)|数据是从真实的电子商务网站中收集的。它是原始数据,即没有任何内容转换,但是,由于保密问题,所有值都被哈希化。|[Kaggle](https://www.kaggle.com/retailrocket/ecommerce-dataset)|
37+
|[Netflix](https://paddlerec.bj.bcebos.com/datasets/Netflix/Netflix.zip)|这是Netflix竞赛中使用的官方数据集。|[Kaggle](https://www.kaggle.com/netflix-inc/netflix-prize-data)|
38+
|[FourSquare](https://paddlerec.bj.bcebos.com/datasets/FourSquare/FourSquare.zip)|此数据集包含在纽约和东京进行的大约10个月收集的签到。每个签到都有其时间戳,GPS坐标及其语义相关联。|[Kaggle](https://www.kaggle.com/chetanism/foursquare-nyc-and-tokyo-checkin-dataset)|
39+
|[AmazonBook](https://paddlerec.bj.bcebos.com/datasets/AmazonBook/AmazonBook.tar.gz)|论文原作者处理过的AmazonBook数据集 |[《Controllable Multi-Interest Framework for Recommendation》](https://arxiv.org/abs/2005.09347)|

doc/imgs/mind.png

91.5 KB
Loading

models/recall/mind/README.md

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# MIND(Multi-Interest Network with Dynamic Routing)
2+
3+
以下是本例的简要目录结构及说明:
4+
```
5+
├── data #样例数据
6+
│ ├── demo #demo训练数据
7+
│ │ └── demo.txt
8+
│ └── valid #demo测试数据
9+
│ └── part-0
10+
├── config.yaml #demo数据配置
11+
├── config_bigdata.yaml #全量数据配置
12+
├── infer.py #评测动态图
13+
├── dygraph_model.py #构建动态图
14+
├── mind_reader.py #训练数据reader
15+
├── mind_infer_reader.py #评测数据reader
16+
├── net.py #模型核心组网(动静合一)
17+
├── static_infer.py #评测静态图
18+
└── static_model.py #构建静态图
19+
```
20+
21+
注:在阅读该示例前,建议您先了解以下内容:
22+
23+
[paddlerec入门教程](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md)
24+
25+
## 内容
26+
- [模型简介](#模型简介)
27+
- [数据准备](#数据准备)
28+
- [运行环境](#运行环境)
29+
- [快速开始](#快速开始)
30+
- [模型组网](#模型组网)
31+
- [效果复现](#效果复现)
32+
- [进阶使用](#进阶使用)
33+
- [FAQ](#FAQ)
34+
35+
## 模型简介
36+
本例实现了基于动态路由的用户多兴趣网络,如下图所示:
37+
<p align="center">
38+
<img align="center" src="../../../doc/imgs/mind.png">
39+
<p>
40+
Multi-Interest Network with Dynamic Routing (MIND) 是通过构建用户和商品向量在统一的向量空间的多个用户兴趣向量,以表达用户多样的兴趣分布。然后通过向量召回技术,利用这多个兴趣向量去检索出TopK个与其近邻的商品向量,得到 TopK个 用户感兴趣的商品。其核心是一个基于胶囊网络和动态路由的(B2I Dynamic Routing)Multi-Interest Extractor Layer。
41+
42+
推荐参考论文:[http://cn.arxiv.org/abs/1904.08030](http://cn.arxiv.org/abs/1904.08030)
43+
44+
## 数据准备
45+
在模型目录的data目录下为您准备了快速运行的示例数据,训练数据、测试数据、词表文件依次保存在data/train, data/test文件夹中。若需要使用全量数据可以参考下方效果复现部分。
46+
47+
训练数据的格式如下:
48+
```
49+
0,17978,0
50+
0,901,1
51+
0,97224,2
52+
0,774,3
53+
0,85757,4
54+
```
55+
分别表示uid、item_id和点击的顺序(时间戳)
56+
57+
测试数据的格式如下:
58+
```
59+
user_id:487766 target_item:0 hist_item:17784 hist_item:126 hist_item:36 hist_item:124 hist_item:34 hist_item:1 hist_item:134 hist_item:6331 hist_item:141 hist_item:4336 hist_item:1373 eval_item:1062 eval_item:867 eval_item:62
60+
user_id:487793 target_item:0 hist_item:153428 hist_item:132997 hist_item:155723 hist_item:66546 hist_item:335397 hist_item:1926 eval_item:1122 eval_item:10105
61+
user_id:487805 target_item:0 hist_item:291025 hist_item:25190 hist_item:2820 hist_item:26047 hist_item:47259 hist_item:36376 eval_item:260145 eval_item:83865
62+
user_id:487811 target_item:0 hist_item:180837 hist_item:202701 hist_item:184587 hist_item:211642 eval_item:101621 eval_item:55716
63+
user_id:487820 target_item:0 hist_item:268524 hist_item:44318 hist_item:35153 hist_item:70847 eval_item:238318
64+
user_id:487825 target_item:0 hist_item:35602 hist_item:4353 hist_item:1540 hist_item:72921 eval_item:501
65+
```
66+
其中`hist_item``eval_item`均是变长序列,读取方式可以看`mind_infer_reader.py`
67+
68+
## 运行环境
69+
PaddlePaddle>=2.0
70+
71+
python 2.7/3.5/3.6/3.7
72+
73+
os : windows/linux/macos
74+
75+
## 快速开始
76+
77+
在mind模型目录的快速执行命令如下:
78+
```
79+
# 进入模型目录
80+
# cd models/recall/mind # 在任意目录均可运行
81+
# 动态图训练
82+
python -u ../../../tools/trainer.py -m config.yaml
83+
# 动态图预测
84+
python -u infer.py -m config.yaml -top_n 50 #对测试数据进行预测,并通过faiss召回候选结果评测Reacll、NDCG、HitRate指标
85+
86+
# 静态图训练
87+
python -u ../../../tools/static_trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml
88+
# 静态图预测
89+
python -u static_infer.py -m config.yaml -top_n 50 #对测试数据进行预测,并通过faiss召回候选结果评测Reacll、NDCG、HitRate指标
90+
```
91+
92+
## 模型组网
93+
94+
细节见上面[模型简介](#模型简介)部分
95+
96+
### 效果复现
97+
由于原始论文没有提供实验的复现细节,为了方便使用者能够快速的跑通每一个模型,我们使用论文[ComiRec](https://arxiv.org/abs/2005.09347)提供的AmazonBook数据集和训练任务进行复现。我们在每个模型下都提供了样例数据。如果需要复现readme中的效果,请按如下步骤依次操作即可。
98+
99+
在全量数据下模型的指标如下:
100+
| 模型 | batch_size | epoch_num| Recall@50 | NDCG@50 | HitRate@50 |Time of each epoch |
101+
| :------| :------ | :------ | :------| :------ | :------| :------ |
102+
| mind(静态图) | 128 | 6 | 5.61% | 8.96% | 11.81% | -- |
103+
| mind(动态图) | 128 | 6 | 5.54% | 8.85% | 11.75% | -- |
104+
105+
1. 确认您当前所在目录为PaddleRec/models/recall/mind
106+
2. 进入paddlerec/datasets/AmazonBook目录下执行run.sh脚本,会下载处理完成的AmazonBook数据集,并解压到指定目录
107+
```bash
108+
cd ../../../datasets/AmazonBook
109+
sh run.sh
110+
```
111+
3. 切回模型目录,执行命令运行全量数据
112+
```bash
113+
cd - # 切回模型目录
114+
# 动态图训练
115+
python -u ../../../tools/trainer.py -m config_bigdata.yaml # 全量数据运行config_bigdata
116+
python -u infer.py -m config_bigdata.yaml # 全量数据运行config_bigdata
117+
```
118+
119+
## 进阶使用
120+
121+
## FAQ
122+
123+
## 参考
124+
125+
数据集及训练任务参考了[ComiRec](https://github.com/THUDM/ComiRec)

models/recall/mind/config.yaml

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
runner:
16+
train_data_dir: "data/train"
17+
train_reader_path: "mind_reader" # importlib format
18+
use_gpu: True
19+
use_auc: False
20+
train_batch_size: 128
21+
epochs: 2
22+
print_interval: 500
23+
model_save_path: "output_model_mind"
24+
infer_batch_size: 128
25+
infer_reader_path: "mind_infer_reader" # importlib format
26+
test_data_dir: "data/valid"
27+
infer_load_path: "output_model_mind"
28+
infer_start_epoch: 0
29+
infer_end_epoch: 1
30+
31+
# distribute_config
32+
# sync_mode: "async"
33+
# split_file_list: False
34+
# thread_num: 1
35+
36+
37+
# hyper parameters of user-defined network
38+
hyper_parameters:
39+
# optimizer config
40+
optimizer:
41+
class: Adam
42+
learning_rate: 0.005
43+
# strategy: async
44+
# user-defined <key, value> pairs
45+
item_count: 367983
46+
embedding_dim: 64
47+
hidden_size: 64
48+
neg_samples: 1280
49+
maxlen: 20
50+
pow_p: 1.0
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
runner:
16+
train_data_dir: "../../../datasets/AmazonBook/train"
17+
train_reader_path: "mind_reader" # importlib format
18+
use_gpu: True
19+
use_auc: False
20+
train_batch_size: 128
21+
epochs: 6
22+
print_interval: 500
23+
model_save_path: "output_model_mind"
24+
infer_batch_size: 128
25+
infer_reader_path: "mind_infer_reader" # importlib format
26+
test_data_dir: "../../../datasets/AmazonBook/valid"
27+
infer_load_path: "output_model_mind"
28+
infer_start_epoch: 0
29+
infer_end_epoch: 1
30+
31+
# distribute_config
32+
# sync_mode: "async"
33+
# split_file_list: False
34+
# thread_num: 1
35+
36+
37+
# hyper parameters of user-defined network
38+
hyper_parameters:
39+
# optimizer config
40+
optimizer:
41+
class: Adam
42+
learning_rate: 0.005
43+
# strategy: async
44+
# user-defined <key, value> pairs
45+
item_count: 367983
46+
embedding_dim: 64
47+
hidden_size: 64
48+
neg_samples: 1280
49+
maxlen: 20
50+
pow_p: 1.0

0 commit comments

Comments
 (0)