|
| 1 | +# DSelect-k(DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning) |
| 2 | + |
| 3 | +代码请参考:[DSelect_K](https://github.com/PaddlePaddle/PaddleRec/blob/master/models/multitask/dselect_k) |
| 4 | +如果我们的代码对您有用,还请点个star啊~ |
| 5 | + |
| 6 | +## 内容 |
| 7 | + |
| 8 | +- [模型简介](#模型简介) |
| 9 | +- [数据准备](#数据准备) |
| 10 | +- [运行环境](#运行环境) |
| 11 | +- [快速开始](#快速开始) |
| 12 | +- [效果复现](#效果复现) |
| 13 | +- [进阶使用](#进阶使用) |
| 14 | +- [FAQ](#FAQ) |
| 15 | + |
| 16 | +## 模型简介 |
| 17 | + |
| 18 | +`MoE(Mixture of Experts)` 架构在改善多任务学习 MTL(Multi-Task Learning) 中的参数共享和扩展大容量神经网络方面显示出良好的效果。SOTA 的 MoE |
| 19 | +类模型使用一个可训练的稀疏门控来为每个输入实例选择一个专家子集。虽然概念上可行有效,但现有的稀疏门控例如 Top-K, 并不平滑(意味着不可导)。在使用基于梯度的方法进行训练时,平滑度的缺失会导致收敛和统计性能问题。 |
| 20 | +本文基于二进制编码方法提出了 `DSelect-k: a continuously differentiable and sparse gate for MoE`, 解决了现有稀疏门控不可导的弊端,可以根据梯度下降类方法进行训练。 |
| 21 | + |
| 22 | + |
| 23 | + |
| 24 | +上图是 MoE 和 MMoE 的结构图, 本文所提出的 DSelect-k 模型用于从 N 个专家中选择 Top-K 个进行后续任务预测。其主要有两种模式,1)Static; 2) Per-example; |
| 25 | +前者不感知输入,所有实例会选择同样的专家子集,后者恰恰相反。 |
| 26 | + |
| 27 | +```text |
| 28 | +@article{hazimeh2021dselectk, |
| 29 | + title={DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning}, |
| 30 | + author={Hussein Hazimeh and Zhe Zhao and Aakanksha Chowdhery and Maheswaran Sathiamoorthy and Yihua Chen and Rahul Mazumder and Lichan Hong and Ed H. Chi}, |
| 31 | + year={2021}, |
| 32 | + eprint={2106.03760}, |
| 33 | + archivePrefix={arXiv}, |
| 34 | + primaryClass={cs.LG} |
| 35 | +} |
| 36 | +``` |
| 37 | + |
| 38 | +## 数据准备 |
| 39 | + |
| 40 | +训练及测试数据集选用的是 [Multi-MNIST](https://paperswithcode.com/dataset/multimnist) |
| 41 | +数据集,该数据集是在 [Dynamic Routing Between Capsules](https://paperswithcode.com/paper/dynamic-routing-between-capsules) |
| 42 | +首次介绍提出,后续一些 MTL 论文大多沿用该数据集。 |
| 43 | + |
| 44 | +> The MultiMNIST dataset is generated from MNIST. The training and tests are generated by overlaying a digit on top of another digit from the same set (training or test) but different class. Each digit is shifted up to 4 pixels in each direction resulting in a 36×36 image. |
| 45 | +
|
| 46 | + |
| 47 | + |
| 48 | +上图是其中一张图片,与经典数据集 MNIST 不同,该图片上包含两位数字,左上和右下,对应两个多分类任务。数据集划分为训练集、验证集、测试集,数量分别为 100000、20000、20000。 |
| 49 | + |
| 50 | +在 PaperswithCode 网站上检索到 【NeurlPS 2019】[Pareto Multi-Task Learning](https://arxiv.org/pdf/1912.12854v1.pdf) 公布了该数据集, |
| 51 | +下载链接: https://drive.google.com/drive/folders/1VnmCmBAVh8f_BKJg1KYx-E137gBLXbGG。 |
| 52 | + |
| 53 | + |
| 54 | +## 运行环境 |
| 55 | + |
| 56 | +PaddlePaddle>=2.1 |
| 57 | + |
| 58 | +python 2.7/3.5/3.6/3.7 |
| 59 | + |
| 60 | +os : windows/linux/macos |
| 61 | + |
| 62 | +## 快速开始 |
| 63 | + |
| 64 | +本文提供了样例数据可以供您快速体验,在任意目录下均可执行。在 dselect_k 模型目录的快速执行命令如下: |
| 65 | + |
| 66 | +```bash |
| 67 | +# 进入模型目录 |
| 68 | +# cd models/multitask/dselect_k # 在任意目录均可运行 |
| 69 | +# 动态图训练 |
| 70 | +python -u ../../../tools/trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml |
| 71 | +# 动态图预测 |
| 72 | +python -u ../../../tools/infer.py -m config.yaml |
| 73 | + |
| 74 | +# 静态图训练 |
| 75 | +python -u ../../../tools/static_trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml |
| 76 | +# 静态图预测 |
| 77 | +python -u ../../../tools/static_infer.py -m config.yaml |
| 78 | +``` |
| 79 | + |
| 80 | +## 效果复现 |
| 81 | + |
| 82 | +为了方便使用者能够快速的跑通每一个模型,我们在每个模型下都提供了样例数据。如果需要复现 readme 中的效果,请按如下步骤依次操作即可。 在全量数据下模型的指标如下: |
| 83 | + |
| 84 | + |
| 85 | +| 模型 | Accuracy1 | Accuracy2 | batch_size | epoch_num| Time of each epoch | |
| 86 | +| :------| :------ | :------ | :------ | :------| :------ | |
| 87 | +| DSelect-k | 0.930460 | 0.916088 | 256 | 100 | 约 0.5 小时 | |
| 88 | + |
| 89 | +1. 确认您当前所在目录为 PaddleRec/models/multitask/dselect_k |
| 90 | + |
| 91 | +2. 进入 paddlerec/datasets/Multi_MNIST_DselectK 目录下,执行该脚本,会从国内源的服务器上下载数据集,并解压到指定文件夹。 |
| 92 | + |
| 93 | +``` bash |
| 94 | +cd ../../../datasets/Multi_MNIST_DselectK |
| 95 | +sh run.sh |
| 96 | +``` |
| 97 | + |
| 98 | +3. 切回模型目录,执行命令运行全量数据 |
| 99 | + |
| 100 | +```bash |
| 101 | +# 切回模型目录 PaddleRec/models/multitask/dselect_k |
| 102 | +# 动态图训练 |
| 103 | +python -u ../../../tools/trainer.py -m config_bigdata.yaml # 全量数据运行config_bigdata.yaml |
| 104 | +python -u ../../../tools/infer.py -m config_bigdata.yaml # 全量数据运行config_bigdata.yaml |
| 105 | +``` |
| 106 | + |
| 107 | +## 进阶使用 |
| 108 | + |
| 109 | +## FAQ |
0 commit comments