Commit 59dc03f

Merge pull request #684 from wangzhen38/lwfx_dselectk
Add DSelect_K
2 parents 380c47b + edc66ed commit 59dc03f

19 files changed: +1161 -0 lines changed

README_CN.md

Lines changed: 1 addition & 0 deletions
@@ -165,6 +165,7 @@ python -u tools/static_trainer.py -m models/rank/dnn/config.yaml # 静态图训
 | Multi-Task | [MMOE](models/multitask/mmoe/)([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/mmoe.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3238934) ||| >=2.1.0 | [KDD 2018][Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) |
 | Multi-Task | [ShareBottom](models/multitask/share_bottom/)([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/share_bottom.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3238943) ||| >=2.1.0 | [1998][Multitask learning](http://reports-archive.adm.cs.cmu.edu/anon/1997/CMU-CS-97-203.pdf) |
 | Multi-Task | [Maml](models/multitask/maml/)([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/maml.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3238412) | x | x | >=2.1.0 | [PMLR 2017][Model-agnostic meta-learning for fast adaptation of deep networks](https://arxiv.org/pdf/1703.03400.pdf) |
+| Multi-Task | [DSelect_K](models/multitask/dselect_k/)([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/dselect_k.html)) | - | x | x | >=2.1.0 | [NeurIPS 2021][DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning](https://arxiv.org/pdf/2106.03760v3.pdf) |
 | Re-Rank | [Listwise](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rerank/listwise/) | - || x | [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [2019][Sequential Evaluation and Generation Framework for Combinatorial Recommender System](https://arxiv.org/pdf/1902.00245.pdf) |

README_EN.md

Lines changed: 1 addition & 0 deletions
@@ -155,6 +155,7 @@ python -u tools/static_trainer.py -m models/rank/dnn/config.yaml # Training wit
 | Multi-Task | [MMOE](models/multitask/mmoe/)<br>([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/mmoe.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3238934) ||| >=2.1.0 | [KDD 2018][Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) |
 | Multi-Task | [ShareBottom](models/multitask/share_bottom/)<br>([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/share_bottom.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3238943) ||| >=2.1.0 | [1998][Multitask learning](http://reports-archive.adm.cs.cmu.edu/anon/1997/CMU-CS-97-203.pdf) |
 | Multi-Task | [Maml](models/multitask/maml/)<br>([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/maml.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3238412) | x | x | >=2.1.0 | [PMLR 2017][Model-agnostic meta-learning for fast adaptation of deep networks](https://arxiv.org/pdf/1703.03400.pdf) |
+| Multi-Task | [DSelect_K](models/multitask/dselect_k/)<br>([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/dselect_k.html)) | - | x | x | >=2.1.0 | [NeurIPS 2021][DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning](https://arxiv.org/pdf/2106.03760v3.pdf) |
 | Re-Rank | [Listwise](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rerank/listwise/) | - || x | [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [2019][Sequential Evaluation and Generation Framework for Combinatorial Recommender System](https://arxiv.org/pdf/1902.00245.pdf) |
 
 <h2 align="center">Community</h2>

contributor.md

Lines changed: 4 additions & 0 deletions
@@ -12,5 +12,9 @@
 | [BERT4REC](models/rank/bert4rec/) | [jinweiluo](https://github.com/jinweiluo) | https://github.com/PaddlePaddle/PaddleRec/pull/624 | Paper Reproduction Challenge, Round 4 |
 | [FAT_DeepFFM](models/rank/fat_deepffm/) | [LinJayan](https://github.com/LinJayan) | https://github.com/PaddlePaddle/PaddleRec/pull/651 | Paper Reproduction Challenge, Round 4 |
 | [DeepRec](models/rank/deeprec/) | [chenjiyan2001](https://github.com/chenjiyan2001) | https://github.com/PaddlePaddle/PaddleRec/pull/647 | Paper Reproduction Challenge, Round 5 |
+| [ENSFM](models/recal/ensfm/) | [renmada](https://github.com/renmada) | https://github.com/PaddlePaddle/PaddleRec/pull/618 | Paper Reproduction Challenge, Round 5 |
+| [TiSAS](models/recal/tisas/) | [renmada](https://github.com/renmada) | https://github.com/PaddlePaddle/PaddleRec/pull/625 | Paper Reproduction Challenge, Round 5 |
+| [AutoFIS](models/rank/autofis/) | [renmada](https://github.com/renmada) | https://github.com/PaddlePaddle/PaddleRec/pull/660 | Paper Reproduction Challenge, Round 5 |
+| [Dselect_K](models/multitask/dselect_k/) | [Andy1314Chen](https://github.com/Andy1314Chen) | https://github.com/PaddlePaddle/PaddleRec/pull/671 | Paper Reproduction Challenge, Round 5 |
 
 </div>
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
wget -c https://paddlerec.bj.bcebos.com/datasets/Multi_Mnist_Dselet_K/multi_mnist.zip
unzip multi_mnist.zip

mkdir train test
mv train.pickle ./train/train.pickle
mv test.pickle ./test/test.pickle
rm -rf multi_mnist.zip train.pickle test.pickle
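
The script above leaves `train/train.pickle` and `test/test.pickle` on disk, but the layout of the objects stored in them is not described in this commit. A quick way to find out before writing or debugging a reader is to load a file and look at its top-level object. The sketch below is a minimal illustration; `describe_pickle` is a hypothetical helper, not part of PaddleRec. Only unpickle files from a source you trust, because loading a pickle can execute code.

```python
import pickle


def describe_pickle(path):
    """Load a pickle file and summarise its top-level object."""
    with open(path, "rb") as f:
        obj = pickle.load(f)
    summary = {"type": type(obj).__name__}
    # Containers such as lists, tuples and dicts report how many items they hold.
    if hasattr(obj, "__len__"):
        summary["len"] = len(obj)
    return summary


if __name__ == "__main__":
    import sys

    # Usage: python describe_pickle.py train/train.pickle
    for path in sys.argv[1:]:
        print(path, describe_pickle(path))
```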

doc/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@
 models/multitask/mmoe.md
 models/multitask/ple.md
 models/multitask/share_bottom.md
+models/multitask/dselect_k.md
 
 .. toctree::
 :maxdepth: 1
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
# DSelect-k (DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning)

For the code, see [DSelect_K](https://github.com/PaddlePaddle/PaddleRec/blob/master/models/multitask/dselect_k).
If you find our code useful, please give the repository a star.

## Contents

- [Model Overview](#model-overview)
- [Data Preparation](#data-preparation)
- [Environment](#environment)
- [Quick Start](#quick-start)
- [Reproducing the Results](#reproducing-the-results)
- [Advanced Usage](#advanced-usage)
- [FAQ](#FAQ)

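## Model Overview

The `MoE (Mixture of Experts)` architecture has shown good results both in improving parameter sharing in multi-task learning (MTL) and in scaling high-capacity neural networks. State-of-the-art MoE models use a trainable sparse gate to select a subset of experts for each input example. Although conceptually appealing, existing sparse gates such as Top-K are not smooth (that is, not differentiable). When training with gradient-based methods, this lack of smoothness can cause convergence and statistical performance problems.
Building on a binary encoding scheme, the paper proposes `DSelect-k: a continuously differentiable and sparse gate for MoE`, which removes this limitation of existing sparse gates and can be trained with gradient-descent-style methods.

![](https://tva1.sinaimg.cn/large/008i3skNly1gy3rpouuc6j30rw0d63zy.jpg)

The figure above shows the MoE and MMoE architectures. The DSelect-k gate proposed in the paper selects k out of N experts for the downstream task prediction. It has two modes: 1) Static and 2) Per-example.
The static gate does not look at the input, so every example selects the same subset of experts; the per-example gate makes the selection depend on the input.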
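
To make the binary-encoding idea concrete, here is a minimal pure-Python sketch of a static gate as the DSelect-k paper describes it. Each of the k selectors owns one logit per bit of the expert index, a cubic smooth-step maps each logit to a value in [0, 1], and a softmax mixes the k selectors. This is an illustration under stated assumptions, not the PaddleRec implementation in this commit: the function names, the bit ordering of the expert index, and the default `gamma` are all choices made for the example.

```python
import math


def smooth_step(t, gamma=1.0):
    """Cubic smooth-step: 0 below -gamma/2, 1 above gamma/2, smooth in between."""
    if t <= -gamma / 2:
        return 0.0
    if t >= gamma / 2:
        return 1.0
    return -2.0 / gamma**3 * t**3 + 3.0 / (2.0 * gamma) * t + 0.5


def single_expert_selector(logits, gamma=1.0):
    """Turn m logits into weights over 2**m experts using a binary encoding.

    Expert i gets the product of z_j for every set bit j of i and (1 - z_j)
    for every unset bit, so the weights always sum to 1.
    """
    z = [smooth_step(a, gamma) for a in logits]
    weights = []
    for i in range(2 ** len(z)):
        w = 1.0
        for j, z_j in enumerate(z):
            w *= z_j if (i >> j) & 1 else 1.0 - z_j
        weights.append(w)
    return weights


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    shift = max(xs)
    exps = [math.exp(x - shift) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def static_gate(selector_logits, mix_logits, gamma=1.0):
    """Combine k single-expert selectors into one gate over all experts."""
    mix = softmax(mix_logits)
    gate = [0.0] * (2 ** len(selector_logits[0]))
    for p, logits in zip(mix, selector_logits):
        for i, s in enumerate(single_expert_selector(logits, gamma)):
            gate[i] += p * s
    return gate


# 8 experts need 3 bits; k = 2 selectors whose logits have saturated.
gate = static_gate([[5.0, -5.0, 5.0], [-5.0, -5.0, -5.0]], [0.0, 0.0])
print(gate)
```

With these saturated logits the gate puts weight 0.5 on experts 0 and 5 and zero on the rest, so at most k experts are active while every operation stays differentiable in the interior of the smooth-step. A per-example gate would follow the same structure, with the logits computed from the input instead of being free parameters.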

```text
@article{hazimeh2021dselectk,
title={DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning},
author={Hussein Hazimeh and Zhe Zhao and Aakanksha Chowdhery and Maheswaran Sathiamoorthy and Yihua Chen and Rahul Mazumder and Lichan Hong and Ed H. Chi},
year={2021},
eprint={2106.03760},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```

## Data Preparation

The training and test data come from the [Multi-MNIST](https://paperswithcode.com/dataset/multimnist) dataset, which was first introduced in [Dynamic Routing Between Capsules](https://paperswithcode.com/paper/dynamic-routing-between-capsules) and has since been reused by many MTL papers.

> The MultiMNIST dataset is generated from MNIST. The training and tests are generated by overlaying a digit on top of another digit from the same set (training or test) but different class. Each digit is shifted up to 4 pixels in each direction resulting in a 36×36 image.

![](https://tva1.sinaimg.cn/large/008i3skNly1gy3ryidh3hj30f40ea3yp.jpg)

The figure above shows one image from the dataset. Unlike the classic MNIST dataset, each image contains two digits, one in the top-left and one in the bottom-right, which correspond to two multi-class classification tasks. The dataset is split into training, validation, and test sets of 100000, 20000, and 20000 examples respectively.
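
The construction quoted above can be sketched in a few lines. The example below places two 28×28 digits on a 36×36 canvas, each shifted by up to 4 pixels in every direction, and flattens the result to 36 × 36 = 1296 values. It is an illustration only: combining the two digits with a pixel-wise maximum and the helper names are assumptions made for this sketch, not the procedure used to build the released dataset.

```python
import random

CANVAS = 36      # side of the output image
DIGIT = 28       # side of an MNIST digit
MAX_SHIFT = 4    # maximum shift in each direction


def place(digit, dx, dy):
    """Copy a 28x28 digit onto an empty 36x36 canvas, shifted by (dx, dy)."""
    canvas = [[0] * CANVAS for _ in range(CANVAS)]
    top, left = MAX_SHIFT + dy, MAX_SHIFT + dx
    for r in range(DIGIT):
        for c in range(DIGIT):
            canvas[top + r][left + c] = digit[r][c]
    return canvas


def overlay_digits(a, b, rng):
    """Overlay two digits with independent random shifts (pixel-wise max)."""
    dxa, dya, dxb, dyb = (rng.randint(-MAX_SHIFT, MAX_SHIFT) for _ in range(4))
    ca, cb = place(a, dxa, dya), place(b, dxb, dyb)
    return [[max(x, y) for x, y in zip(ra, rb)] for ra, rb in zip(ca, cb)]


def flatten(image):
    """Row-major flattening of a 2-D image into a feature vector."""
    return [p for row in image for p in row]


if __name__ == "__main__":
    rng = random.Random(0)
    a = [[255] * DIGIT for _ in range(DIGIT)]   # stand-in for a real digit
    b = [[0] * DIGIT for _ in range(DIGIT)]
    print(len(flatten(overlay_digits(a, b, rng))))  # 1296
```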

The dataset was located through the Papers with Code website; it was published with [NeurIPS 2019] [Pareto Multi-Task Learning](https://arxiv.org/pdf/1912.12854v1.pdf).
Download link: https://drive.google.com/drive/folders/1VnmCmBAVh8f_BKJg1KYx-E137gBLXbGG


## Environment

PaddlePaddle>=2.1

python 2.7/3.5/3.6/3.7

os : windows/linux/macos

## Quick Start

Sample data is provided so that you can try the model quickly. The model can be run from any directory; the commands below use paths relative to the dselect_k model directory:

```bash
# Enter the model directory
# cd models/multitask/dselect_k # can be run from any directory
# Dynamic-graph training
python -u ../../../tools/trainer.py -m config.yaml # use config_bigdata.yaml for the full dataset
# Dynamic-graph inference
python -u ../../../tools/infer.py -m config.yaml

# Static-graph training
python -u ../../../tools/static_trainer.py -m config.yaml # use config_bigdata.yaml for the full dataset
# Static-graph inference
python -u ../../../tools/static_infer.py -m config.yaml
```

## Reproducing the Results

To help users get each model running quickly, sample data is provided with every model. To reproduce the results reported in this readme, follow the steps below in order. On the full dataset the model reaches the following metrics:


| Model | Accuracy1 | Accuracy2 | batch_size | epoch_num | Time per epoch |
| :------| :------ | :------ | :------ | :------| :------ |
| DSelect-k | 0.930460 | 0.916088 | 256 | 100 | about 0.5 hours |

1. Make sure your current directory is PaddleRec/models/multitask/dselect_k.

2. Go to the paddlerec/datasets/Multi_MNIST_DselectK directory and run the script there. It downloads the dataset from a mirror server in China and extracts it into the target folders.

``` bash
cd ../../../datasets/Multi_MNIST_DselectK
sh run.sh
```

3. Switch back to the model directory and run on the full dataset.

```bash
# Switch back to the model directory PaddleRec/models/multitask/dselect_k
# Dynamic-graph training
python -u ../../../tools/trainer.py -m config_bigdata.yaml # full-dataset run with config_bigdata.yaml
python -u ../../../tools/infer.py -m config_bigdata.yaml # full-dataset run with config_bigdata.yaml
```

## Advanced Usage

## FAQ

doc/source/readme.md

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
 [mmoe](https://paddlerec.readthedocs.io/en/latest/models/multitask/mmoe.html)
 [ple](https://paddlerec.readthedocs.io/en/latest/models/multitask/ple.html)
 [share_bottom](https://paddlerec.readthedocs.io/en/latest/models/multitask/share_bottom.html)
+[dselect_k](https://paddlerec.readthedocs.io/en/latest/models/multitask/dselect_k.html)
 
 ## Ranking Models
 [bst](https://paddlerec.readthedocs.io/en/latest/models/rank/bst.html)
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

runner:
  train_data_dir: "data/sample_data"
  train_reader_path: "multiMNIST_reader" # importlib format
  use_gpu: False
  use_auc: True
  train_batch_size: 8
  epochs: 10
  print_interval: 2
  #model_init_path: "output_model/0" # init model
  model_save_path: "output_model_dselect_k"
  test_data_dir: "data/sample_data"
  infer_batch_size: 8
  infer_reader_path: "multiMNIST_reader" # importlib format
  infer_load_path: "output_model_dselect_k"
  infer_start_epoch: 0
  infer_end_epoch: 10
  #use inference save model
  use_inference: False
  save_inference_feed_varnames: ["input"]
  save_inference_fetch_varnames: ["slice_0.tmp_0", "slice_1.tmp_0"]

hyper_parameters:
  feature_size: 1296
  top_k: 2
  expert_num: 8
  gate_num: 2
  expert_size: 50
  tower_size: 8
  optimizer:
    class: adam
    learning_rate: 0.001
    strategy: async
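
The hyperparameters above can be cross-checked with a little arithmetic. The sketch below assumes, following the binary-encoding description of DSelect-k, that each of the `top_k` selectors uses one logit per bit of the expert index plus one mixing logit, and that `feature_size` is the flattened 36×36 MultiMNIST image. `gate_shape` is a hypothetical helper, not a PaddleRec function, and the parameter count applies to a static gate only.

```python
def gate_shape(expert_num, top_k):
    """Return (bits per selector, parameters of one static gate)."""
    if not 1 <= top_k <= expert_num:
        raise ValueError("top_k must be between 1 and expert_num")
    # Number of bits needed to index expert_num experts (3 bits for 8 experts).
    bits = (expert_num - 1).bit_length()
    # top_k selectors with `bits` logits each, plus top_k mixing logits.
    return bits, top_k * bits + top_k


bits, params = gate_shape(expert_num=8, top_k=2)
print(bits, params)      # 3 8
print(36 * 36 == 1296)   # True: feature_size matches a flattened 36x36 image
```

With `gate_num: 2`, which presumably means one gate per task, this would come to 16 gate parameters in the static case; a per-example gate would instead learn a linear map from the input to these logits.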
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

runner:
  train_data_dir: "../../../datasets/Multi_MNIST_DselectK/train"
  train_reader_path: "multiMNIST_reader" # importlib format
  use_gpu: True
  use_auc: True
  train_batch_size: 256
  epochs: 100
  print_interval: 32
  #model_init_path: "output_model/0" # init model
  model_save_path: "output_model_dselect_k_all"
  test_data_dir: "../../../datasets/Multi_MNIST_DselectK/test"
  infer_batch_size: 20000
  infer_reader_path: "multiMNIST_reader" # importlib format
  infer_load_path: "output_model_dselect_k_all"
  infer_start_epoch: 0
  infer_end_epoch: 100
  #use inference save model
  use_inference: False
  save_inference_feed_varnames: ["input"]
  save_inference_fetch_varnames: ["slice_0.tmp_0", "slice_1.tmp_0"]

hyper_parameters:
  feature_size: 1296
  top_k: 2
  expert_num: 8
  gate_num: 2
  expert_size: 50
  tower_size: 8
  optimizer:
    class: adam
    learning_rate: 0.001
    strategy: async
