|
| 1 | +# BART |
| 2 | + |
| 3 | +## 模型简介 |
| 4 | + |
| 5 | +BART是一种Seq2Seq结构的降噪自编码器,通过增加噪声来破环文本然后重建原文本来训练模型。它使用一个标准的Transformer结构,可以被看作泛化的BERT(由于是双向编码器),GPT(由于是从左到右解码器),和一些其他的预训练模型结构。 |
| 6 | + |
| 7 | +本项目是BART在 PaddlePaddle 2.2上开源实现的文本摘要的例子,包含了在[CNN/DailyMail](https://arxiv.org/pdf/1704.04368.pdf)数据集上微调和生成的代码。 |
| 8 | + |
| 9 | +## 快速开始 |
| 10 | + |
| 11 | +### 环境依赖 |
| 12 | + |
| 13 | +- nltk |
| 14 | +- rouge_score |
| 15 | + |
| 16 | +安装方式:`pip install -r requirements.txt` |
| 17 | + |
| 18 | +### 代码结构说明 |
| 19 | + |
| 20 | +以下是本项目主要代码结构及说明: |
| 21 | + |
| 22 | +```text |
| 23 | +. |
| 24 | +├── run_summarization.py # 模型finetune主程序入口 |
| 25 | +├── generate.py # 模型生成主程序入口 |
| 26 | +├── utils.py # 定义参数及一些工具函数 |
| 27 | +└── README.md # 文档说明 |
| 28 | +``` |
| 29 | + |
| 30 | +### 数据准备 |
| 31 | + |
| 32 | +**CNN/DailyMail**数据集是一个英文数据集,包含CNN和《每日邮报》记者撰写的30多万篇独特新闻文章,常用来做文本摘要。 |
| 33 | + |
| 34 | +为了方便用户快速测试,PaddleNLP Dataset API内置了CNN/DailyMail数据集,一键即可完成数据集加载,示例代码如下: |
| 35 | + |
| 36 | +```python |
| 37 | +from paddlenlp.datasets import load_dataset |
| 38 | +train_set, dev_set, test_set = load_dataset("cnn_dailymail", splits=["train", "dev", "test"]) |
| 39 | +``` |
| 40 | + |
| 41 | +### 模型训练 |
| 42 | + |
| 43 | +运行如下命令即可在训练集上进行finetune,并在验证集上进行验证 |
| 44 | + |
| 45 | +```shell |
| 46 | +# GPU启动,参数`--gpus`指定训练所用的GPU卡号,可以是单卡,也可以多卡 |
| 47 | +# 例如使用1号和2号卡,则:`--gpu 1,2` |
| 48 | +unset CUDA_VISIBLE_DEVICES |
| 49 | +python -m paddle.distributed.launch --gpus 1,2 run_summarization.py \ |
| 50 | + --model_name_or_path=bart-base \ |
| 51 | + --dataset_name=cnn_dailymail \ |
| 52 | + --output_dir=output \ |
| 53 | + --max_source_length=1024 \ |
| 54 | + --max_target_length=142 \ |
| 55 | + --learning_rate=1e-4 \ |
| 56 | + --num_train_epochs=6 \ |
| 57 | + --logging_steps=100 \ |
| 58 | + --save_steps=1000 \ |
| 59 | + --seed=42 \ |
| 60 | + --train_batch_size=20 \ |
| 61 | + --eval_batch_size=64 \ |
| 62 | + --warmup_proportion=0.1 \ |
| 63 | + --ignore_pad_token_for_loss=True \ |
| 64 | + --device=gpu |
| 65 | +``` |
| 66 | + |
| 67 | +其中参数释义如下: |
| 68 | +- `gpus` 指示了训练所用的GPU |
| 69 | + |
| 70 | +- `model_name_or_path` 指示了finetune使用的预训练模型,可以是PaddleNLP提供的预训练模型,或者是本地的模型。如果使用本地的模型,则配置为本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle模型参数model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。 |
| 71 | + |
| 72 | + | PaddleNLP提供的预训练模型 | |
| 73 | + |---------------------------------| |
| 74 | + | bart-base | |
| 75 | + | bart-large | |
| 76 | + |
| 77 | +- `dataset_name` 表示训练的数据集。 |
| 78 | + |
| 79 | +- `output_dir` 表示模型的保存路径。 |
| 80 | + |
| 81 | +- `max_source_length` 表示输入article的最大长度。 |
| 82 | + |
| 83 | +- `max_target_length` 表示输入highlights的最大长度。 |
| 84 | + |
| 85 | +- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。 |
| 86 | + |
| 87 | +- `num_train_epochs` 表示训练轮数。 |
| 88 | + |
| 89 | +- `logging_steps` 表示日志打印间隔。 |
| 90 | + |
| 91 | +- `save_steps` 表示模型保存及评估间隔。 |
| 92 | + |
| 93 | +- `seed` 表示随机数生成器的种子。 |
| 94 | + |
| 95 | +- `epochs` 表示训练轮数。 |
| 96 | + |
| 97 | +- `train_batch_size` 表示训练**每张卡**上的样本数目。 |
| 98 | + |
| 99 | +- `eval_batch_size` 表示预测**单卡**上的样本数目。 |
| 100 | + |
| 101 | +- `warmup_proportion` 表示warmup_steps所占总步数的比例。学习率逐渐升高到基础学习率(即上面配置的learning_rate)所需要的迭代数。 |
| 102 | + |
| 103 | +- `ignore_pad_token_for_loss` 表示计算loss时忽略padding。 |
| 104 | + |
| 105 | +- `device` 表示使用的设备。 |
| 106 | + |
| 107 | +程序运行时将会自动进行训练和验证,训练过程中会自动保存模型在指定的`output_dir`中。如: |
| 108 | + |
| 109 | +```text |
| 110 | +./output/ |
| 111 | +├── bart_model_1000.pdparams |
| 112 | +│ ├── model_config.json |
| 113 | +│ ├── model_state.pdparams |
| 114 | +│ ├── merges.txt |
| 115 | +│ ├── tokenizer_config.json |
| 116 | +│ └── vocab.json |
| 117 | +└── ... |
| 118 | +``` |
| 119 | + |
| 120 | +**NOTE:** 如需恢复模型训练,只需指定`model_name_or_path`为本地微调模型的路径即可。 |
| 121 | + |
| 122 | +### 模型预测 |
| 123 | + |
| 124 | +运行如下命令即可在验证集上进行测试 |
| 125 | + |
| 126 | +```shell |
| 127 | +# GPU启动,预测仅支持单卡 |
| 128 | +export CUDA_VISIBLE_DEVICES=0 |
| 129 | +python generate.py \ |
| 130 | + --model_name_or_path=bart-base-cnndm-model \ |
| 131 | + --dataset_name=cnn_dailymail \ |
| 132 | + --output_path=generate.txt \ |
| 133 | + --max_source_length=1024 \ |
| 134 | + --max_target_length=142 \ |
| 135 | + --decode_strategy=greedy_search \ |
| 136 | + --top_k=2 \ |
| 137 | + --top_p=1.0 \ |
| 138 | + --num_beams=1 \ |
| 139 | + --length_penalty=0.0 \ |
| 140 | + --batch_size=64 \ |
| 141 | + --seed=42 \ |
| 142 | + --ignore_pad_token_for_loss=True \ |
| 143 | + --logging_steps=100 \ |
| 144 | + --device=gpu |
| 145 | +``` |
| 146 | + |
| 147 | +其中参数释义如下: |
| 148 | +- `model_name_or_path` 指示了预测使用的模型,可以是PaddleNLP提供的预训练模型,或者是本地的模型。如果使用本地的模型,则配置为本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle模型参数model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。 |
| 149 | + |
| 150 | + | PaddleNLP提供的预训练模型 | |
| 151 | + |---------------------------------| |
| 152 | + | bart-base | |
| 153 | + | bart-large | |
| 154 | + |
| 155 | +- `dataset_name` 表示预测的数据集。 |
| 156 | + |
| 157 | +- `output_path` 表示预测结果的保存路径。 |
| 158 | + |
| 159 | +- `max_source_length` 表示输入article的最大长度。 |
| 160 | + |
| 161 | +- `max_target_length` 表示输入highlights的最大长度。 |
| 162 | + |
| 163 | +- `decode_strategy` 表示预测解码时采取的策略,可选"sampling"、"greedy_search"和"beam_search"之一。 |
| 164 | + |
| 165 | +- `top_k` 表示采用"sampling"解码策略时,token的概率按从大到小排序,生成的token只从前`top_k`个中进行采样。 |
| 166 | + |
| 167 | +- `top_p` 表示采用"sampling"解码策略时,从词表中采样并选择概率之和大于给定阈值`top_p`的token。 |
| 168 | + |
| 169 | +- `num_beams` 表示besm search的beam size。 |
| 170 | + |
| 171 | +- `length_penalty` 表示besm search生成长度的指数惩罚。 |
| 172 | + |
| 173 | +- `batch_size` 表示每次迭代**单卡**上的样本数目。 |
| 174 | + |
| 175 | +- `seed` 表示随机数生成器的种子。 |
| 176 | + |
| 177 | +- `ignore_pad_token_for_loss` 表示训练时计算loss时忽略padding。如果训练时设置为True,那么预测时的label需要还原来计算评估指标。 |
| 178 | + |
| 179 | +- `logging_steps` 表示日志打印间隔。 |
| 180 | + |
| 181 | +- `device` 表示使用的设备。 |
| 182 | + |
| 183 | +程序运行结束后会将预测生成的摘要保存在`output_path`中。同时终端中会输出评估结果。 |
| 184 | + |
| 185 | +采用预训练模型及微调模型在验证集上有如下结果: |
| 186 | + |
| 187 | +| model_name_or_path | Rouge-1 | Rouge-2 | Rouge-L | |
| 188 | +| :----------------------: | :-------------: | :-------------: |:-------------: | |
| 189 | +| [bart-base-cnndm-model](https://paddlenlp.bj.bcebos.com/models/transformers/bart/bart-base-cnndm-model.tar.gz ) | 43.6446 | 20.1447 | 41.0132 | |
| 190 | + |
| 191 | +**NOTE:** `bart-base-cnndm-model`是按本项目中的超参finetune得到的结果。 |
| 192 | + |
| 193 | +### 模型高性能预测 |
| 194 | + |
| 195 | +在模型预测阶段,我们提供了基于 FasterTransformer 的高性能预测的选项,可以选择性开启是否需要采用高性能预测。只需在上述模型预测上添加三个参数即可:分别是`faster`,`use_fp16_decoding`,`decoding_lib`。 |
| 196 | + |
| 197 | +```shell |
| 198 | +# GPU启动,预测仅支持单卡 |
| 199 | +export CUDA_VISIBLE_DEVICES=0 |
| 200 | +python generate.py \ |
| 201 | + --model_name_or_path=bart-base-cnndm-model \ |
| 202 | + --dataset_name=cnn_dailymail \ |
| 203 | + --output_path=generate.txt \ |
| 204 | + --max_source_length=1024 \ |
| 205 | + --max_target_length=142 \ |
| 206 | + --decode_strategy=greedy_search \ |
| 207 | + --top_k=2 \ |
| 208 | + --top_p=1.0 \ |
| 209 | + --num_beams=1 \ |
| 210 | + --length_penalty=0.0 \ |
| 211 | + --batch_size=64 \ |
| 212 | + --seed=42 \ |
| 213 | + --ignore_pad_token_for_loss=True \ |
| 214 | + --logging_steps=100 \ |
| 215 | + --faster \ |
| 216 | + --use_fp16_decoding \ |
| 217 | + --decoding_lib=../../../paddlenlp/ops/build/lib/libdecoding_op.so \ |
| 218 | + --device=gpu |
| 219 | +``` |
| 220 | +其中新增参数释义如下: |
| 221 | +- `faster` 表示是否开启高性能预测。设置 `--faster` 即表示开启。 |
| 222 | +- `use_fp16_decoding` 表示在开启高性能预测的时候,是否使用 fp16 来完成预测过程。设置 `--use_fp16_decoding` 即表示使用 fp16 进行预测,否则使用 fp32。 |
| 223 | +- `decoding_lib` 如果不存在,将使用 JIT 自动编译所需的动态库。如果需要自行编译,可参考[自定义OP编译使用](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/ops/README.md#%E7%BC%96%E8%AF%91%E8%87%AA%E5%AE%9A%E4%B9%89op) ,然后设定为编译出的高性能自定义 OP的动态库的位置即可。 |
| 224 | + |
| 225 | +## 参考文献 |
| 226 | +1. Lewis M , Liu Y , Goyal N , et al. [BART: Denoising Sequence-to-Sequence Pre-training for Natural |
| 227 | +Language Generation, Translation, and Comprehension](https://aclanthology.org/2020.acl-main.703.pdf)[C]//Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics. 2020: 7871-7880. |
| 228 | +2. See A , Liu P J , CD Manning. [Get To The Point: Summarization with Pointer-Generator Networks](https://aclanthology.org/P17-1099.pdf)[C]// Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics. 2017: 1073–1083. |
0 commit comments