Skip to content

Commit 713a120

Browse files
authored
Merge pull request #427 from frankwhzhang/sharebtm_0506
add Sharebtm 2.0
2 parents 9364512 + 9990d10 commit 713a120

File tree

12 files changed

+594
-13
lines changed

12 files changed

+594
-13
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@
6969
| 排序 | [FGCNN](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rank/fgcnn/) ||||| [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [WWW 2019][Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1904.04447.pdf) |
7070
| 排序 | [Fibinet](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rank/fibinet/) ||||| [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [RecSys19][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction]( https://arxiv.org/pdf/1905.09433.pdf) |
7171
| 排序 | [Flen](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rank/flen/) ||||| [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [2019][FLEN: Leveraging Field for Scalable CTR Prediction]( https://arxiv.org/pdf/1911.04690.pdf) |
72-
| 多任务 | PLE ||||| 1.8.5 | [RecSys 2020][Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/abs/10.1145/3383313.3412236) |
72+
| 多任务 | [PLE](models/multitask/ple) ||||| 2.0 | [RecSys 2020][Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/abs/10.1145/3383313.3412236) |
7373
| 多任务 | [ESMM](models/multitask/esmm/) ||||| 2.0 | [SIGIR 2018][Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate](https://arxiv.org/abs/1804.07931) |
7474
| 多任务 | [MMOE](models/multitask/mmoe/) ||||| 2.0 | [KDD 2018][Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) |
75-
| 多任务 | [ShareBottom](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/multitask/share-bottom/) ||||| [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [1998][Multitask learning](http://reports-archive.adm.cs.cmu.edu/anon/1997/CMU-CS-97-203.pdf) |
75+
| 多任务 | [ShareBottom](models/multitask/share_bottom/) ||||| 2.0 | [1998][Multitask learning](http://reports-archive.adm.cs.cmu.edu/anon/1997/CMU-CS-97-203.pdf) |
7676
| 重排序 | [Listwise](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rerank/listwise/) |||| x | [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [2019][Sequential Evaluation and Generation Framework for Combinatorial Recommender System](https://arxiv.org/pdf/1902.00245.pdf) |
7777

7878

models/multitask/ple/README.md

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
- [FAQ](#FAQ)
3535

3636
## 模型简介
37-
多任务模型通过学习不同任务的联系和差异,可提高每个任务的学习效率和质量。多任务学习的的框架广泛采用shared-bottom的结构,不同任务间共用底部的隐层。这种结构本质上可以减少过拟合的风险,但是效果上可能受到任务差异和数据分布带来的影响。 论文[《Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts》]( https://www.kdd.org/kdd2018/accepted-papers/view/modeling-task-relationships-in-multi-task-learning-with-multi-gate-mixture- )中提出了一个Multi-gate Mixture-of-Experts(MMOE)的多任务学习结构。
37+
多任务模型通过学习不同任务的联系和差异,可提高每个任务的学习效率和质量。但在多任务场景中经常出现跷跷板现象,即有些任务表现良好,有些任务表现变差。 论文[《Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations》](https://dl.acm.org/doi/abs/10.1145/3383313.3412236 ) ,论文提出了Progressive Layered Extraction (简称PLE),来解决多任务学习的跷跷板现象。
38+
39+
我们在Paddlepaddle定义PLE的网络结构,在开源数据集Census-income Data上验证模型效果。
3840

39-
## 数据准备
40-
我们在开源数据集Census-income Data上验证模型效果,在模型目录的data目录下为您准备了快速运行的示例数据,若需要使用全量数据可以参考下方[效果复现](#效果复现)部分.
4141
数据的格式如下:
4242
生成的格式以逗号为分割点
4343
```
@@ -55,7 +55,7 @@ os : windows/linux/macos
5555
本文提供了样例数据可以供您快速体验,在任意目录下均可执行。在mmoe模型目录的快速执行命令如下:
5656
```bash
5757
# 进入模型目录
58-
# cd models/multitask/mmoe # 在任意目录均可运行
58+
# cd models/multitask/ple # 在任意目录均可运行
5959
# 动态图训练
6060
python -u ../../../tools/trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml
6161
# 动态图预测
@@ -68,20 +68,15 @@ python -u ../../../tools/static_infer.py -m config.yaml
6868
```
6969

7070
## 模型组网
71-
MMOE模型刻画了任务相关性,基于共享表示来学习特定任务的函数,避免了明显增加参数的缺点。模型的主要组网结构如下:
72-
[MMoE](https://dl.acm.org/doi/abs/10.1145/3219819.3220007):
73-
<p align="center">
74-
<img align="center" src="../../../doc/imgs/mmoe.png">
75-
<p>
7671

7772
### 效果复现
7873
为了方便使用者能够快速的跑通每一个模型,我们在每个模型下都提供了样例数据。如果需要复现readme中的效果,请按如下步骤依次操作即可。
7974
在全量数据下模型的指标如下:
8075
| 模型 | auc_marital | batch_size | epoch_num | Time of each epoch |
8176
| :------| :------ | :------ | :------| :------ |
82-
| MMOE | 0.99 | 32 | 100 | 约1分钟 |
77+
| PLE | 0.99 | 32 | 100 | 约1分钟 |
8378

84-
1. 确认您当前所在目录为PaddleRec/models/multitask/mmoe
79+
1. 确认您当前所在目录为PaddleRec/models/multitask/ple
8580
2. 进入paddlerec/datasets/census目录下,执行该脚本,会从国内源的服务器上下载我们预处理完成的census全量数据集,并解压到指定文件夹。
8681
``` bash
8782
cd ../../../datasets/census
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# ShareBottom
2+
3+
以下是本例的简要目录结构及说明:
4+
5+
```
6+
├── data # 文档
7+
├── train #训练数据
8+
├── train_data.txt
9+
├── test #测试数据
10+
├── test_data.txt
11+
├── __init__.py
12+
├── README.md #文档
13+
├── config.yaml # sample数据配置
14+
├── config_bigdata.yaml # 全量数据配置
15+
├── census_reader.py # 数据读取程序
16+
├── net.py # 模型核心组网(动静统一)
17+
├── static_model.py # 构建静态图
18+
├── dygraph_model.py # 构建动态图
19+
```
20+
21+
注:在阅读该示例前,建议您先了解以下内容:
22+
23+
[paddlerec入门教程](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md)
24+
25+
## 内容
26+
27+
- [模型简介](#模型简介)
28+
- [数据准备](#数据准备)
29+
- [运行环境](#运行环境)
30+
- [快速开始](#快速开始)
31+
- [模型组网](#模型组网)
32+
- [效果复现](#效果复现)
33+
- [进阶使用](#进阶使用)
34+
- [FAQ](#FAQ)
35+
36+
## 模型简介
37+
share_bottom是多任务学习的基本框架,其特点是对于不同的任务,底层的参数和网络结构是共享的,这种结构的优点是极大地减少网络的参数数量的情况下也能很好地对多任务进行学习,但缺点也很明显,由于底层的参数和网络结构是完全共享的,因此对于相关性不高的两个任务会导致优化冲突,从而影响模型最终的结果。后续很多Neural-based的多任务模型都是基于share_bottom发展而来的,如MMOE等模型可以改进share_bottom在多任务之间相关性低导致模型效果差的缺点。
38+
39+
我们在Paddlepaddle实现share_bottom网络结构,并在开源数据集Census-income Data上验证模型效果。
40+
41+
## 数据准备
42+
我们在开源数据集Census-income Data上验证模型效果,在模型目录的data目录下为您准备了快速运行的示例数据,若需要使用全量数据可以参考下方[效果复现](#效果复现)部分.
43+
数据的格式如下:
44+
生成的格式以逗号为分割点
45+
```
46+
0,0,73,0,0,0,0,1700.09,0,0
47+
```
48+
49+
## 运行环境
50+
PaddlePaddle>=2.0
51+
52+
python 2.7/3.5/3.6/3.7
53+
54+
os : windows/linux/macos
55+
56+
## 快速开始
57+
本文提供了样例数据可以供您快速体验,在任意目录下均可执行。在mmoe模型目录的快速执行命令如下:
58+
```bash
59+
# 进入模型目录
60+
# cd models/multitask/share_bottom # 在任意目录均可运行
61+
# 动态图训练
62+
python -u ../../../tools/trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml
63+
# 动态图预测
64+
python -u ../../../tools/infer.py -m config.yaml
65+
66+
# 静态图训练
67+
python -u ../../../tools/static_trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml
68+
# 静态图预测
69+
python -u ../../../tools/static_infer.py -m config.yaml
70+
```
71+
72+
## 模型组网
73+
74+
### 效果复现
75+
为了方便使用者能够快速的跑通每一个模型,我们在每个模型下都提供了样例数据。如果需要复现readme中的效果,请按如下步骤依次操作即可。
76+
在全量数据下模型的指标如下:
77+
| 模型 | auc_marital | batch_size | epoch_num | Time of each epoch |
78+
| :------| :------ | :------ | :------| :------ |
79+
| MMOE | 0.99 | 32 | 100 | 约1分钟 |
80+
81+
1. 确认您当前所在目录为PaddleRec/models/multitask/share_bottom
82+
2. 进入paddlerec/datasets/census目录下,执行该脚本,会从国内源的服务器上下载我们预处理完成的census全量数据集,并解压到指定文件夹。
83+
``` bash
84+
cd ../../../datasets/census
85+
sh run.sh
86+
```
87+
3. 切回模型目录,执行命令运行全量数据
88+
```bash
89+
cd - # 切回模型目录
90+
# 动态图训练
91+
python -u ../../../tools/trainer.py -m config_bigdata.yaml # 全量数据运行config_bigdata.yaml
92+
python -u ../../../tools/infer.py -m config_bigdata.yaml # 全量数据运行config_bigdata.yaml
93+
```
94+
95+
## 进阶使用
96+
97+
## FAQ
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
import numpy as np
17+
18+
from paddle.io import IterableDataset
19+
20+
21+
class RecDataset(IterableDataset):
22+
def __init__(self, file_list, config):
23+
super(RecDataset, self).__init__()
24+
self.file_list = file_list
25+
self.config = config
26+
27+
def __iter__(self):
28+
full_lines = []
29+
self.data = []
30+
for file in self.file_list:
31+
with open(file, "r") as rf:
32+
for l in rf:
33+
l = l.strip().split(',')
34+
l = list(map(float, l))
35+
label_income = []
36+
label_marital = []
37+
data = l[2:]
38+
if int(l[1]) == 0:
39+
label_income = [0]
40+
elif int(l[1]) == 1:
41+
label_income = [1]
42+
if int(l[0]) == 0:
43+
label_marital = [0]
44+
elif int(l[0]) == 1:
45+
label_marital = [1]
46+
output_list = []
47+
output_list.append(np.array(data).astype('float32'))
48+
output_list.append(np.array(label_income).astype('int64'))
49+
output_list.append(np.array(label_marital).astype('int64'))
50+
yield output_list
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
runner:
16+
train_data_dir: "data/train"
17+
train_reader_path: "census_reader" # importlib format
18+
use_gpu: False
19+
use_auc: True
20+
train_batch_size: 2
21+
epochs: 3
22+
print_interval: 2
23+
#model_init_path: "output_model/0" # init model
24+
model_save_path: "output_model_share_btm"
25+
test_data_dir: "data/test"
26+
infer_batch_size: 2
27+
infer_reader_path: "census_reader" # importlib format
28+
infer_load_path: "output_model_share_btm"
29+
infer_start_epoch: 0
30+
infer_end_epoch: 3
31+
32+
hyper_parameters:
33+
feature_size: 499
34+
bottom_size: 117
35+
task_num: 2
36+
tower_size: 8
37+
optimizer:
38+
class: adam
39+
learning_rate: 0.001
40+
strategy: async
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
runner:
16+
train_data_dir: "../../../datasets/census/train_all"
17+
train_reader_path: "census_reader" # importlib format
18+
use_gpu: False
19+
use_auc: True
20+
train_batch_size: 32
21+
epochs: 100
22+
print_interval: 100
23+
#model_init_path: "output_model/0" # init model
24+
model_save_path: "output_model_share_btm_all"
25+
test_data_dir: "../../../datasets/census/test_all"
26+
infer_batch_size: 32
27+
infer_reader_path: "census_reader" # importlib format
28+
infer_load_path: "output_model_share_btm_all"
29+
infer_start_epoch: 0
30+
infer_end_epoch: 100
31+
32+
33+
hyper_parameters:
34+
feature_size: 499
35+
bottom_size: 117
36+
task_num: 2
37+
tower_size: 8
38+
optimizer:
39+
class: adam
40+
learning_rate: 0.001
41+
strategy: async

0 commit comments

Comments
 (0)