Commit 9364512

Merge pull request #426 from frankwhzhang/ple_0430
Ple 2.0
2 parents 9b286ef + 1eb742b commit 9364512

10 files changed: +742 -0 lines changed

models/multitask/ple/README.md

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# MMOE

The directory structure of this example is outlined below:

```
├── data # sample data
    ├── train # training data
        ├── train_data.txt
    ├── test # test data
        ├── test_data.txt
├── __init__.py
├── README.md # documentation
├── config.yaml # configuration for the sample data
├── config_bigdata.yaml # configuration for the full dataset
├── census_reader.py # data reader
├── net.py # core model network (shared by dynamic and static graph)
├── static_model.py # builds the static graph
├── dygraph_model.py # builds the dynamic graph
```

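Note: before reading this example, we recommend that you first go through the following:

[PaddleRec getting-started tutorial](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md)

## Contents

- [Model Overview](#model-overview)
- [Data Preparation](#data-preparation)
- [Runtime Environment](#runtime-environment)
- [Quick Start](#quick-start)
- [Model Network](#model-network)
- [Reproducing the Results](#reproducing-the-results)
- [Advanced Usage](#advanced-usage)
- [FAQ](#faq)
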
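## Model Overview
By learning both the connections and the differences between tasks, a multi-task model can improve the learning efficiency and quality of each task. Multi-task learning frameworks widely adopt the shared-bottom structure, in which different tasks share the bottom hidden layers. This structure inherently reduces the risk of overfitting, but its effectiveness can suffer from differences between tasks and from the data distribution. The paper [Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://www.kdd.org/kdd2018/accepted-papers/view/modeling-task-relationships-in-multi-task-learning-with-multi-gate-mixture-) proposes a multi-task learning structure called Multi-gate Mixture-of-Experts (MMOE).
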
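## Data Preparation
We validate the model on the open-source Census-income Data dataset. Sample data for a quick run is provided in the data directory under the model directory. To use the full dataset, see the [Reproducing the Results](#reproducing-the-results) section below.
The data format is as follows:
Each record is one line of comma-separated values:
```
0,0,73,0,0,0,0,1700.09,0,0
```
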
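To make the record layout concrete, here is a minimal sketch that splits the sample line into features and labels. The column order (column 0 = marital-status label, column 1 = income label, remaining columns = features) is taken from `census_reader.py` in this commit; the helper name `parse_census_line` is hypothetical. Note that the sample line above carries only 8 feature values while `config.yaml` sets `feature_size: 499`, so it is presumably abbreviated.

```python
def parse_census_line(line):
    """Split one comma-separated record into (features, income, marital)."""
    values = [float(v) for v in line.strip().split(",")]
    label_marital = int(values[0])  # column 0: marital-status label (per census_reader.py)
    label_income = int(values[1])   # column 1: income label (per census_reader.py)
    features = values[2:]           # every remaining column is an input feature
    return features, label_income, label_marital


features, income, marital = parse_census_line("0,0,73,0,0,0,0,1700.09,0,0")
print(len(features), income, marital)
```
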
## Runtime Environment
PaddlePaddle>=2.0

python 2.7/3.5/3.6/3.7

os : windows/linux/macos

## Quick Start
Sample data is provided so that you can try the model quickly; the commands can be run from any directory. The quick-start commands in the mmoe model directory are as follows:
```bash
# Enter the model directory
# cd models/multitask/mmoe # can be run from any directory
# Dynamic graph training
python -u ../../../tools/trainer.py -m config.yaml # use config_bigdata.yaml for the full dataset
# Dynamic graph inference
python -u ../../../tools/infer.py -m config.yaml

# Static graph training
python -u ../../../tools/static_trainer.py -m config.yaml # use config_bigdata.yaml for the full dataset
# Static graph inference
python -u ../../../tools/static_infer.py -m config.yaml
```

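## Model Network
The MMOE model captures the relationships between tasks and learns task-specific functions on top of a shared representation, without a significant increase in the number of parameters. The main network structure is shown below:
[MMoE](https://dl.acm.org/doi/abs/10.1145/3219819.3220007):
<p align="center">
<img align="center" src="../../../doc/imgs/mmoe.png">
</p>
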
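The gating idea can be illustrated with a small, framework-free sketch. It is not the `net.py` implementation from this commit: it only follows the general multi-gate mixture-of-experts formulation, in which every expert transforms the same input and each task owns a softmax gate that weights the expert outputs. All sizes, the single-layer ReLU experts, and the helper names are assumptions chosen for illustration.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def linear(x, weights):
    """Apply a weight matrix (list of output rows) to the vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in weights]


def relu(v):
    return [max(0.0, a) for a in v]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


def mmoe_forward(x, expert_weights, gate_weights):
    """Return one gated mixture of expert outputs per task."""
    expert_outputs = [relu(linear(x, w)) for w in expert_weights]
    task_inputs = []
    for gate in gate_weights:
        gate_probs = softmax(linear(x, gate))  # one weight per expert
        mixed = [
            sum(p * out[j] for p, out in zip(gate_probs, expert_outputs))
            for j in range(len(expert_outputs[0]))
        ]
        task_inputs.append(mixed)
    return task_inputs


rng = random.Random(0)
feature_size, expert_num, expert_size, task_num = 8, 4, 16, 2  # illustrative sizes
x = [rng.random() for _ in range(feature_size)]
experts = [random_matrix(expert_size, feature_size, rng) for _ in range(expert_num)]
gates = [random_matrix(expert_num, feature_size, rng) for _ in range(task_num)]
task_inputs = mmoe_forward(x, experts, gates)
print(len(task_inputs), len(task_inputs[0]))
```

In a full model, each entry of `task_inputs` would then be fed to that task's own tower network, which is omitted here.
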
## Reproducing the Results
To help users get every model running quickly, we provide sample data with each model. To reproduce the results reported in this README, follow the steps below in order.
Model metrics on the full dataset:
| Model | auc_marital | batch_size | epoch_num | Time per epoch |
| :------| :------ | :------ | :------| :------ |
| MMOE | 0.99 | 32 | 100 | about 1 minute |

1. Make sure your current directory is PaddleRec/models/multitask/mmoe
2. Go to the paddlerec/datasets/census directory and run the script there. It downloads our preprocessed full census dataset from a server in mainland China and extracts it into the target folder.
``` bash
cd ../../../datasets/census
sh run.sh
```
3. Switch back to the model directory and run on the full dataset
```bash
cd - # switch back to the model directory
# Dynamic graph training
python -u ../../../tools/trainer.py -m config_bigdata.yaml # config_bigdata.yaml runs on the full dataset
python -u ../../../tools/infer.py -m config_bigdata.yaml # config_bigdata.yaml runs on the full dataset
```

## Advanced Usage

## FAQ

models/multitask/ple/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import numpy as np

from paddle.io import IterableDataset


class RecDataset(IterableDataset):
    def __init__(self, file_list, config):
        super(RecDataset, self).__init__()
        self.file_list = file_list
        self.config = config

    def __iter__(self):
        for file in self.file_list:
            with open(file, "r") as rf:
                for l in rf:
                    # Each line is comma-separated:
                    # column 0 = marital label, column 1 = income label,
                    # remaining columns = input features.
                    l = l.strip().split(',')
                    l = list(map(float, l))
                    label_income = []
                    label_marital = []
                    data = l[2:]
                    # A label list stays empty if the value is neither 0 nor 1.
                    if int(l[1]) == 0:
                        label_income = [0]
                    elif int(l[1]) == 1:
                        label_income = [1]
                    if int(l[0]) == 0:
                        label_marital = [0]
                    elif int(l[0]) == 1:
                        label_marital = [1]
                    # Yield [features, income label, marital label] as arrays.
                    output_list = []
                    output_list.append(np.array(data).astype('float32'))
                    output_list.append(np.array(label_income).astype('int64'))
                    output_list.append(np.array(label_marital).astype('int64'))
                    yield output_list
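The reader above leaves a label list empty when the raw value is neither 0 nor 1, so such a sample would carry a zero-length label array. As a possible alternative, the hypothetical helper below (not part of this commit) fails loudly instead. It is also stricter than the reader's `int(...)` comparison, which truncates a value such as 0.7 to 0.

```python
def parse_binary_label(value, name):
    """Return 0 or 1 for a raw label value, raising on anything else."""
    if value not in (0.0, 1.0):
        raise ValueError("%s label must be 0 or 1, got %r" % (name, value))
    return int(value)


print(parse_binary_label(1.0, "income"))
print(parse_binary_label(0.0, "marital"))
```
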

models/multitask/ple/config.yaml

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

runner:
  train_data_dir: "data/train"
  train_reader_path: "census_reader" # importlib format
  use_gpu: False
  use_auc: True
  train_batch_size: 2
  epochs: 3
  print_interval: 2
  #model_init_path: "output_model/0" # init model
  model_save_path: "output_model_ple"
  test_data_dir: "data/test"
  infer_batch_size: 2
  infer_reader_path: "census_reader" # importlib format
  infer_load_path: "output_model_ple"
  infer_start_epoch: 0
  infer_end_epoch: 3

hyper_parameters:
  feature_size: 499
  task_num: 2
  shared_num: 2
  exp_per_task: 3
  level_number: 1
  expert_size: 16
  tower_size: 8
  optimizer:
    class: adam
    learning_rate: 0.001
    strategy: async
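To show how the expert-related hyperparameters above might fit together, here is a small arithmetic sketch. It assumes the layout described for Progressive Layered Extraction (PLE): every task owns `exp_per_task` experts, `shared_num` experts are shared, a task gate mixes that task's experts plus the shared ones, and a shared gate mixes all experts. The `net.py` of this commit is not shown here, so the mapping from these keys to the network is an assumption, and `ple_expert_counts` is a hypothetical helper.

```python
# Values copied from the hyper_parameters section of config.yaml.
hyper_parameters = {
    "task_num": 2,
    "shared_num": 2,
    "exp_per_task": 3,
    "level_number": 1,
    "expert_size": 16,
    "tower_size": 8,
}


def ple_expert_counts(hp):
    """Count experts and gate widths for one extraction level (assumed PLE layout)."""
    total_experts = hp["task_num"] * hp["exp_per_task"] + hp["shared_num"]
    task_gate_width = hp["exp_per_task"] + hp["shared_num"]  # own + shared experts
    shared_gate_width = total_experts                        # all experts
    return total_experts, task_gate_width, shared_gate_width


total, task_gate, shared_gate = ple_expert_counts(hyper_parameters)
print(total, task_gate, shared_gate)  # 8 5 8
```
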
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

runner:
  train_data_dir: "../../../datasets/census/train_all"
  train_reader_path: "census_reader" # importlib format
  use_gpu: False
  use_auc: True
  train_batch_size: 32
  epochs: 100
  print_interval: 100
  #model_init_path: "output_model/0" # init model
  model_save_path: "output_model_ple_all"
  test_data_dir: "../../../datasets/census/test_all"
  infer_batch_size: 32
  infer_reader_path: "census_reader" # importlib format
  infer_load_path: "output_model_ple_all"
  infer_start_epoch: 0
  infer_end_epoch: 100


hyper_parameters:
  feature_size: 499
  task_num: 2
  shared_num: 2
  exp_per_task: 3
  level_number: 1
  expert_size: 16
  tower_size: 8
  optimizer:
    class: adam
    learning_rate: 0.001
    strategy: async
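The two configurations share the same `hyper_parameters` and differ only in their `runner` sections. A quick way to see exactly which runner settings change between the sample run and the full-data run is sketched below. The dictionaries are transcribed by hand from the two YAML files in this commit, and `changed_keys` is a hypothetical helper.

```python
# Runner settings transcribed from config.yaml (sample data).
sample_runner = {
    "train_data_dir": "data/train",
    "train_reader_path": "census_reader",
    "use_gpu": False,
    "use_auc": True,
    "train_batch_size": 2,
    "epochs": 3,
    "print_interval": 2,
    "model_save_path": "output_model_ple",
    "test_data_dir": "data/test",
    "infer_batch_size": 2,
    "infer_reader_path": "census_reader",
    "infer_load_path": "output_model_ple",
    "infer_start_epoch": 0,
    "infer_end_epoch": 3,
}

# Runner settings transcribed from the full-data configuration.
full_runner = dict(
    sample_runner,
    train_data_dir="../../../datasets/census/train_all",
    train_batch_size=32,
    epochs=100,
    print_interval=100,
    model_save_path="output_model_ple_all",
    test_data_dir="../../../datasets/census/test_all",
    infer_batch_size=32,
    infer_load_path="output_model_ple_all",
    infer_end_epoch=100,
)


def changed_keys(a, b):
    """Return the sorted keys whose values differ between two dicts."""
    return sorted(k for k in a if a[k] != b.get(k))


for key in changed_keys(sample_runner, full_runner):
    print(key, sample_runner[key], "->", full_runner[key])
```
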
