
Commit 4b870be

Update readme and fix expand for FasterGPT and FasterMbart (#1436)
* add attn_mask input for encoder-decoder
* update readme and fix expand for FasterGPT and FasterMbart
* add image
* fix image size
* fix image
* fix expand for attention mask
* remove paddlenlp perf
* fix doc
* remove attention mask in expand

Co-authored-by: Guo Sheng <[email protected]>
1 parent 6cac0ac commit 4b870be

File tree

6 files changed (+162 / -47 lines)


docs/imgs/bart_perf.png (134 KB)

docs/imgs/gpt_perf.png (134 KB)

examples/faster/faster_generation/README.md

Lines changed: 31 additions & 13 deletions
@@ -2,7 +2,7 @@
 
 FasterGeneration is a high-performance inference feature added in PaddleNLP v2.2 that performs CUDA-based sequence decoding. It can be used with many generative pretrained NLP models, such as GPT, BART, and UnifiedTransformer, and supports multiple decoding strategies, so it is mainly suited to tasks such as machine translation, text continuation, text summarization, and dialogue generation.
 
-Under the hood, the feature relies on [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), a library optimized specifically for the Transformer model family and for various decoding strategies. At the top level it is wrapped inside the `model.generate` function and is switched on and off through the `use_faster` argument (enabled by default). By simply calling generate, users get high-performance inference for their model. The figure below shows how FasterGeneration is launched:
+Under the hood, the feature relies on [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), a library optimized specifically for the Transformer model family and for various decoding strategies. At the top level it is wrapped inside the `model.generate` function and is switched on and off through the `use_faster` argument (disabled by default). By simply calling generate, users get high-performance inference for their model. The figure below shows how FasterGeneration is launched:
 
 
 <p align="center">
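For context on the `use_faster` switch this hunk changes, here is a minimal sketch of how acceleration is toggled per call; the `model` and `input_ids` variables are assumed to come from any supported generative model loaded via `from_pretrained`, and the snippet is illustrative rather than part of the diff:

```python
# Minimal sketch: acceleration is opt-in per call; the rest of the generate()
# signature is identical to the regular, non-accelerated API.
outputs, _ = model.generate(
    input_ids=input_ids,
    max_length=32,
    decode_strategy="sampling",
    top_k=5,
    use_faster=True)  # omit it (or pass False) to use the regular decoder
```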
@@ -13,7 +13,7 @@ FasterGeneration is a high-performance inference feature added in PaddleNLP v2.2
 
 - Full coverage of generative pretrained models, including GPT, BART, mBART, UnifiedTransformer, and UNIMO-text.
 - Support for most mainstream decoding strategies, including Beam Search, Sampling, and Greedy Search, as well as sub-strategies such as Diverse Sibling Search and Length Penalty.
-- Fast decoding: up to **10x** the speed of the non-accelerated generate function and **5x** the speed of the HuggingFace generate function, **with FP16 mixed-precision computation supported**.
+- Fast decoding: up to **18x** the speed of the non-accelerated generate function, **with FP16 mixed-precision computation supported**.
 - Easy to use: the entry point is `model.generate`, used the same way as the non-accelerated generation API. When the acceleration conditions are met, the high-performance operators are JIT-compiled and used for generation; otherwise the call automatically falls back to the non-accelerated API.
 
 ### Inference Model Support
@@ -34,12 +34,28 @@ FasterGeneration is a high-performance inference feature added in PaddleNLP v2.2
 
 ## Performance
 
-FasterGeneration's high-performance decoding is clearly faster than the original generate method and also has a large speed advantage over competing implementations. The test device is a Tesla V100-SXM2-16GB, with FP32 precision.
+FasterGeneration's high-performance decoding is clearly faster than the original generate method and also has a large speed advantage over competing implementations. The performance comparison charts are shown below:
 
-- **BART** (bart-base, batch_size=4, max_length=32) [image]
-- **GPT** (gpt2, batch_size=4, max_length=32) [image]
+- **batch_size = 4, out_seq_len = 32**
+- Device: Tesla V100-SXM2-16GB
+- CUDA version 11.2
+- cudnn version 8
+- torch version 1.10.0+cu113
+- transformers version 4.12.5
 
-For more detailed performance data, see [here](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/experimental/faster_generation/perf).
+**BART** (bart-base, batch_size=4, max_length=32)
+
+<p align="left">
+<img src="../../../docs/imgs/bart_perf.png" width="800" height ="400" />
+</p>
+
+**GPT** (gpt2, batch_size=4, max_length=32)
+
+<p align="left">
+<img src="../../../docs/imgs/gpt_perf.png" width="800" height ="400" />
+</p>
+
+For more detailed performance data, see [here](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/faster/faster_generation/perf).
 
 ## Quick Start
 
@@ -85,7 +101,8 @@ Result: 对影成三人。
 model = GPTLMHeadModel.from_pretrained(model_name)
 ...
 outputs, _ = model.generate(
-    input_ids=inputs_ids, max_length=10, decode_strategy='greedy_search')
+    input_ids=inputs_ids, max_length=10, decode_strategy='greedy_search',
+    use_faster=True)
 ...
 ```
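To see the two added lines in a complete program, here is a hedged end-to-end version of the Quick Start snippet this hunk patches. The model name matches the surrounding README; the prompt is an assumption chosen so that the quoted `Result: 对影成三人。` makes sense, and the exact tokenizer helpers may vary across PaddleNLP versions:

```python
# Hedged sketch of the full Quick Start flow; the only change relative to the
# non-accelerated call is the use_faster=True argument added by this diff.
import paddle
from paddlenlp.transformers import GPTLMHeadModel, GPTChineseTokenizer

model_name = 'gpt-cpm-small-cn-distill'
tokenizer = GPTChineseTokenizer.from_pretrained(model_name)
model = GPTLMHeadModel.from_pretrained(model_name)
model.eval()

prompt = '花间一壶酒,独酌无相亲。举杯邀明月,'
inputs_ids = paddle.to_tensor(
    tokenizer(prompt)["input_ids"], dtype='int64').unsqueeze(0)

outputs, _ = model.generate(
    input_ids=inputs_ids, max_length=10, decode_strategy='greedy_search',
    use_faster=True)
# convert_ids_to_string is assumed here for decoding; expected output: 对影成三人。
print(tokenizer.convert_ids_to_string(outputs[0].numpy().tolist()))
```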

@@ -98,7 +115,7 @@ outputs, _ = model.generate(
 ...
 ```
 
-**NOTE:** If the arguments passed to `model.generate()` do not meet the requirements of the high-performance version, the program prints a message and automatically switches to the non-accelerated version. For example, if we pass `min_length=1`, we get the following message:
+**NOTE:** If the arguments passed to `model.generate()` do not meet the requirements of the high-performance version, the program prints a message and automatically switches to the non-accelerated version. For example, if we pass `min_length=1` in the example above, we get the following message:
 
 ```
 ...
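For completeness, here is a small sketch of the fallback this note describes, using the same setup as the Quick Start sketch above; `min_length` is the unsupported argument named in the text, and the exact wording of the logged message is version dependent:

```python
# Hedged sketch: min_length is not supported by the high-performance decoder,
# so this call logs a message like the one quoted above and transparently
# falls back to the regular generate() implementation.
outputs, _ = model.generate(
    input_ids=inputs_ids, max_length=10, min_length=1,
    decode_strategy='greedy_search', use_faster=True)
```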
@@ -137,6 +154,7 @@ export CUDA_VISIBLE_DEVICES=0
     --num_return_sequences=1 \
     --decode_strategy=sampling \
     --top_k=5 \
+    --faster \
     --device=gpu
 ```
 
@@ -177,7 +195,7 @@ step 30 - 1.435s/step
 
 As you can see, the non-accelerated `generate()` method runs at roughly 1.5 seconds per step.
 
-Next, we pass the `--faster` flag in the launch script, which makes `generate()` receive `use_faster=True` and enables the accelerated mode. We also need to set `--min_dec_len=0`, because FasterGeneration does not support that parameter yet. The new launch arguments are as follows:
+Next, we pass the `--faster` flag in the launch script; this flag passes `use_faster=True` to the `generate()` method and enables the accelerated mode. We also need to set `--min_dec_len=0`, because FasterGeneration does not support that parameter yet. The new launch arguments are as follows:
 
 ```sh
 export CUDA_VISIBLE_DEVICES=0
@@ -202,9 +220,9 @@ export CUDA_VISIBLE_DEVICES=0
 
 ```sh
 [2021-11-23 13:38:09,200] [ DEBUG] - skipping 'FasterTransformer' extension (up-to-date) build
-step 10 - 0.511s/step
-step 20 - 0.343s/step
-step 30 - 0.419s/step
+step 10 - 0.250s/step
+step 20 - 0.156s/step
+step 30 - 0.141s/step
 ```
 
-As you can see, FasterGeneration runs at roughly 0.4 seconds per step, a speedup of more than 3x. Reducing `num_return_sequences` yields an even higher speedup.
+As you can see, FasterGeneration runs at roughly 0.15 seconds per step, a speedup of more than 9x over the non-accelerated version.

examples/faster/faster_generation/perf/README.md

Lines changed: 116 additions & 21 deletions
@@ -5,40 +5,135 @@
 - **Test device:** Tesla V100-SXM2-16GB
 - **Batch Size:** 4
 - **Max Length:** 32
-- **Precision:** FP16
 
-[table]
+## Performance Data
+***
+
+CUDA 10.1, cudnn 7, gcc 8.2
+
+torch version 1.10.0+cu102, transformers version 4.12.5
+
+**BART:**
+
+| Model Size | Decode Strategy| FasterGeneration(FP32)<br>(ms) | FasterGeneration(FP16)<br>(ms) | HF generate<br>(ms) | Speed Up Rate<br>(Faster32/HF) | Speed Up Rate<br>(Faster16/HF) |
+|-----|----|---|---|---|---|---|
+|num_layers = 6<br>num_attention_heads = 12<br>hidden_size = 768<br>(bart-base)|top_k = 1|37.53|34.01|136.89|3.65|4.02|
+| |top_k = 4 |39.33|34.98|146.89|3.73|4.2 |
+| |top_k = 8 |42.35|34.77|136.80|3.23|3.93|
+| |top_k = 16 |40.95|35.45|148.45|3.63|4.19|
+| |top_p = 0.4 |45.83|33.32|184.36|4.02|5.53|
+| |num_beams = 4|44.72|37.51|242.73|5.43|6.47|
+| |num_beams = 8|61.56|40.27|273.93|4.45|6.8 |
+| |num_beams = 16|82.05|46.68|433.51|5.28|9.29|
+|num_layers = 12<br>num_attention_heads = 16<br>hidden_size = 1024<br>(bart-large)|top_k = 1|55.03|45.44|199.27|3.62|4.39|
+| |top_k = 4|70.12|56.81|220.96|3.15|3.89|
+| |top_k = 8|69.96|57.73|201.06|2.87|3.48|
+| |top_k = 16|69.16|59.62|223.73|3.23|3.75|
+| |top_p = 0.4|73.49|61.43|275.86|3.75|4.49|
+| |num_beams = 4|66.44|50.71|277.61|4.18|5.47|
+| |num_beams = 8|135.30|85.75|314.78|2.33|3.67|
+| |num_beams = 16|168.01|100.22|441.95|2.63|4.41|
+
+**GPT:**
+
+| Model Size | Decode Strategy| FasterGeneration(FP32)<br>(ms) | FasterGeneration(FP16)<br>(ms) | HF generate<br>(ms) | Speed Up Rate<br>(Faster32/HF) | Speed Up Rate<br>(Faster16/HF) |
+|-----|----|---|---|---|---|---|
+|num_layers = 12<br>num_attention_heads = 12<br>hidden_size = 768<br>(gpt2)|top_k = 1|69.29|59.20|363.93|5.25|6.15|
+| |top_k = 4|68.07|60.92|391.02|5.74|6.42|
+| |top_k = 8|69.16|60.45|401.18|5.80|6.64|
+| |top_k = 16|73.59|62.40|401.55|5.46|6.44|
+| |top_p = 0.4|95.61|76.26|429.63|4.49|5.63|
+|num_layers = 24<br>num_attention_heads = 16<br>hidden_size = 1024<br>(gpt2-medium)|top_k = 1|127.04|95.13|726.83|5.72|7.64|
+| |top_k = 4|126.74|93.95|694.53|5.48|7.39|
+| |top_k = 8|128.11|94.07|743.63|5.80|7.91|
+| |top_k = 16|126.78|95.00|732.96|5.78|7.72|
+| |top_p = 0.4|143.36|105.40|756.12|5.27|7.17|
+|num_layers = 36<br>num_attention_heads = 20<br>hidden_size = 1280<br>(gpt2-large)|top_k = 1|236.80|200.37|1057.94|4.47|5.28|
+| |top_k = 4|236.69|201.95|1075.17|4.54|5.32|
+| |top_k = 8|237.04|202.00|1084.60|4.58|5.37|
+| |top_k = 16|235.01|201.79|1110.75|4.73|5.5|
+| |top_p = 0.4|270.31|205.84|1111.16|4.11|5.4|
+
+***
+
+CUDA 11.2, cudnn 8, gcc 8.2
+
+torch version 1.10.0+cu113, transformers version 4.12.5
+
+**BART:**
+
+| Model Size | Decode Strategy| FasterGeneration(FP32)<br>(ms) | FasterGeneration(FP16)<br>(ms) | HF generate<br>(ms) | Speed Up Rate<br>(Faster32/HF) | Speed Up Rate<br>(Faster16/HF) |
+|-----|----|---|---|---|---|---|
+|num_layers = 6<br>num_attention_heads = 12<br>hidden_size = 768<br>(bart-base)|top_k = 1|30.08|27.95|166.90|5.55|5.97|
+| |top_k = 4 |30.82|30.01|184.58|5.99|6.15 |
+| |top_k = 8 |32.06|31.05|183.44|5.72|5.91|
+| |top_k = 16 |32.66|32.35|187.14|5.73|5.78|
+| |top_p = 0.4 |37.99|30.25|208.33|5.48|6.89|
+| |num_beams = 4|45.99|37.51|285.01|5.43|7.6|
+| |num_beams = 8|50.12|37.82|316.56|6.32|8.37|
+| |num_beams = 16|67.66|40.98|467.76|6.91|11.41|
+|num_layers = 12<br>num_attention_heads = 16<br>hidden_size = 1024<br>(bart-large)|top_k = 1|50.23|39.08|222.59|4.43|5.7|
+| |top_k = 4|60.59|48.32|307.76|5.08|6.37|
+| |top_k = 8|59.67|49.65|310.49|5.20|6.25|
+| |top_k = 16|59.15|52.68|333.75|5.64|6.34|
+| |top_p = 0.4|61.36|50.83|340.74|5.55|6.7|
+| |num_beams = 4|65.60|53.24|336.28|5.12|6.32|
+| |num_beams = 8|76.20|54.13|396.62|5.20|7.33|
+| |num_beams = 16|102.04|61.11|531.92|5.21|8.7|
+
+**GPT:**
+
+| Model Size | Decode Strategy| FasterGeneration(FP32)<br>(ms) | FasterGeneration(FP16)<br>(ms) | HF generate<br>(ms) | Speed Up Rate<br>(Faster32/HF) | Speed Up Rate<br>(Faster16/HF) |
+|-----|----|---|---|---|---|---|
+|num_layers = 12<br>num_attention_heads = 12<br>hidden_size = 768<br>(gpt2)|top_k = 1|49.75|40.15|483.02|9.71|12.03|
+| |top_k = 4|49.70|41.69|496.63|9.99|11.91|
+| |top_k = 8|51.81|40.81|485.77|9.38|11.9|
+| |top_k = 16|50.36|42.88|488.38|9.70|11.39|
+| |top_p = 0.4|68.30|53.58|544.53|7.97|10.16|
+|num_layers = 24<br>num_attention_heads = 16<br>hidden_size = 1024<br>(gpt2-medium)|top_k = 1|109.86|76.88|936.02|8.52|12.18|
+| |top_k = 4|109.69|78.70|943.71|8.60|11.99|
+| |top_k = 8|109.70|78.39|963.73|8.79|12.29|
+| |top_k = 16|111.18|79.05|945.27|8.50|11.96|
+| |top_p = 0.4|127.54|89.76|999.28|7.83|11.13|
+|num_layers = 36<br>num_attention_heads = 20<br>hidden_size = 1280<br>(gpt2-large)|top_k = 1|205.92|142.85|1368.78|6.65|9.58|
+| |top_k = 4|205.43|140.40|1374.83|6.69|9.79|
+| |top_k = 8|205.62|139.47|1406.42|6.84|10.08|
+| |top_k = 16|205.16|139.77|1392.37|6.79|9.96|
+| |top_p = 0.4|221.06|152.35|1452.07|6.57|9.53|
+
 
 ## Test Method
 
 Run the following command to start the BART performance test:
 
 ```sh
-python bart_perf.py \
-    --model_name_or_path=bart-base \
-    --decode_strategy=sampling \
-    --num_beams=4 \
-    --top_k=16 \
-    --top_p=1.0 \
-    --max_length=32 \
+bash run_perf_bart.sh
 ```
 
 Run the following command to start the GPT performance test:
 
 ```sh
-python gpt_perf.py \
-    --model_name_or_path=gpt2-en \
-    --decode_strategy=sampling \
-    --top_k=1 \
-    --top_p=1.0 \
-    --max_length=32 \
+bash run_perf_gpt.sh
 ```
 
-The parameters are described as follows:
-- `model_name_or_path` specifies the model to test. For BART choose between `bart-base` and `bart-large`; for GPT choose among `gpt2-en`, `gpt2-medium-en`, and `gpt2-large-en`.
-- `decode_strategy` is the decoding strategy used for prediction, one of "sampling", "greedy_search", or "beam_search". **Note that GPT currently does not support beam_search.**
-- `top_k` means that with the "sampling" strategy, tokens are sorted by probability in descending order and generation samples only from the top `top_k` tokens.
-- `top_p` means that with the "sampling" strategy, tokens are sorted by probability in descending order and generation samples only from the leading tokens whose cumulative probability reaches `top_p`.
-- `max_length` is the maximum length of the generated sentence.
+After running the commands above, the scripts automatically run the performance tests across different model configurations; the output looks as follows:
+
+```sh
+...
+[2021-12-10 08:11:37,255] [ DEBUG] - skipping 'FasterTransformer' extension (up-to-date) build
+Namespace(decode_strategy='sampling', max_length=32, model_name_or_path='bart-base', num_beams=1, top_k=1, top_p=1.0, use_fp16_decoding=False)
+Faster FP32 cost: 40.13654176145792
+PD cost: 511.413540635258
+HF cost: 138.49875444546342
+Speed up Faster FP32/PD: 12.741843671403577
+Speed up Faster FP32/HF: 3.4506897796177394
+...
+...
+[2021-12-10 08:13:42,858] [ DEBUG] - skipping 'FasterTransformer' extension (up-to-date) build
+Namespace(decode_strategy='sampling', max_length=32, model_name_or_path='bart-base', num_beams=1, top_k=1, top_p=1.0, use_fp16_decoding=True)
+Faster FP16 cost: 34.004870522767305
+...
+```
+As you can see, for each parameter group the script first prints the FP32 results compared against the competing implementation, and then prints the FP16 numbers separately.
 
 **NOTE:** Depending on the test environment and machine state, the results of the performance test scripts above may differ somewhat from the numbers in the tables.
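For readers who want to sanity-check a single configuration without the shell scripts, below is a hedged sketch of the timing loop the perf scripts effectively run; the model name, decoding settings, and iteration count are illustrative, and the real `run_perf_bart.sh`/`run_perf_gpt.sh` scripts sweep many configurations and also time the HuggingFace baseline:

```python
# Hedged sketch, not the contents of bart_perf.py/gpt_perf.py: time generate()
# with and without use_faster for one decoding configuration.
import time
import paddle
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-base")
model = BartForConditionalGeneration.from_pretrained("bart-base")
model.eval()

# batch_size = 4: repeat one encoded sentence four times
input_ids = paddle.to_tensor(
    [tokenizer("Hello, world!")["input_ids"]] * 4, dtype="int64")

def bench(use_faster, n=10):
    # warm-up call (also triggers the one-off JIT build when use_faster=True)
    model.generate(input_ids=input_ids, max_length=32,
                   decode_strategy="sampling", top_k=1, use_faster=use_faster)
    start = time.perf_counter()
    for _ in range(n):
        model.generate(input_ids=input_ids, max_length=32,
                       decode_strategy="sampling", top_k=1, use_faster=use_faster)
    return (time.perf_counter() - start) / n * 1000  # ms per generate() call

print("faster: %.2f ms  regular: %.2f ms" % (bench(True), bench(False)))
```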

paddlenlp/ops/README.md

Lines changed: 2 additions & 8 deletions
@@ -172,12 +172,6 @@ model = model_class.from_pretrained(args.model_name)
 # Define model
 gpt = FasterGPT(
     model=model,
-    topk=args.topk,
-    topp=args.topp,
-    max_out_len=args.max_out_len,
-    bos_id=bos_id,
-    eos_id=eos_id,
-    temperature=args.temperature,
     decoding_lib=args.decoding_lib,
     use_fp16_decoding=args.use_fp16_decoding)
 ```
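After this hunk, the sampling parameters are no longer constructor arguments; only the model, the decoding library, and the precision flag remain, with the removed parameters presumably supplied at generation time instead. A hedged sketch of building the wrapper after this change (argument values are illustrative):

```python
# Hedged sketch of constructing FasterGPT after this diff; topk/topp/
# max_out_len/bos_id/eos_id/temperature are no longer passed here.
from paddlenlp.transformers import GPTLMHeadModel
from paddlenlp.ops import FasterGPT

model = GPTLMHeadModel.from_pretrained("gpt2-medium-en")
gpt = FasterGPT(
    model=model,
    decoding_lib=None,           # assumption: None means "use the JIT-built extension"
    use_fp16_decoding=False)
```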
@@ -194,7 +188,7 @@ gpt = FasterGPT(
 
 ``` sh
 export CUDA_VISIBLE_DEVICES=0
-python ./faster_transformer/sample/gpt_sample.py --model_name_or_path gpt2-medium-en --batch_size 1 --topk 4 --topp 0.0 --max_out_len 32 --start_token "<|endoftext|>" --end_token "<|endoftext|>" --temperature 1.0
+python ./faster_transformer/sample/gpt_sample.py --model_name_or_path gpt2-medium-en --batch_size 1 --topk 4 --topp 0.0 --max_length 32 --start_token "<|endoftext|>" --end_token "<|endoftext|>" --temperature 1.0
 ```
 
 The meaning of each option is as follows:
@@ -203,7 +197,7 @@ python ./faster_transformer/sample/gpt_sample.py --model_name_or_path gpt2-mediu
 * `--batch_size`: the number of samples in a batch.
 * `--topk`: the `k` used when performing top-k sampling; the default is 4.
 * `--topp`: the threshold used when performing top-p sampling; the default is 0.0, which means top-p sampling is not performed.
-* `--max_out_len`: the maximum generation length.
+* `--max_length`: the maximum generation length.
 * `--start_token`: a string, the start token used for unconditional generation.
 * `--end_token`: a string, the end token for generation.
 * `--temperature`: the temperature setting.

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 13 additions & 5 deletions
@@ -681,8 +681,12 @@ def forward(self,
 
         if num_return_sequences > 1:
             input_ids, model_kwargs = self.expand_inputs_for_generation(
-                input_ids, expand_size=num_return_sequences, seq_len=seq_len)
+                input_ids,
+                expand_size=num_return_sequences,
+                seq_len=seq_len,
+                attention_mask=attention_mask)
             seq_len = model_kwargs["seq_len"]
+            attention_mask = model_kwargs.get("attention_mask", None)
 
         return self.decoding(
             input_ids,
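The change above threads `attention_mask` through the same expansion already applied to `input_ids` and `seq_len`. As an aside, here is an illustrative, standalone sketch of what that expansion amounts to; it is not the library's `expand_inputs_for_generation`, just the tiling idea behind it:

```python
# Illustrative sketch: expanding a batch for num_return_sequences repeats every
# per-sample tensor (including attention_mask) along dim 0, so the decoder sees
# a consistent batch of size batch_size * expand_size.
import paddle

def expand_for_generation(input_ids, expand_size, **tensors):
    index = paddle.tile(
        paddle.arange(input_ids.shape[0]).unsqueeze(-1),
        [1, expand_size]).reshape([-1])
    expanded = {name: paddle.index_select(t, index)
                for name, t in tensors.items() if t is not None}
    return paddle.index_select(input_ids, index), expanded

# e.g. ids, extras = expand_for_generation(ids, 4, attention_mask=mask, seq_len=lens)
```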
@@ -1203,7 +1207,6 @@ def forward(self,
         decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else getattr(
             self._model, 'decoder_start_token_id', None)
 
-        #(gongenlei) Not enable_faster_encoder temporarily
         self.encoder = enable_faster_encoder(self.encoder, need_build=False)
         if encoder_output is None:
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
@@ -1233,10 +1236,15 @@ def forward(self,
         if decoder_start_token_id is not None:
             bos_token_id = decoder_start_token_id
 
-        # TODO(gongenlei) Need to expand
         if forced_bos_token_id is not None:
-            trg_word = paddle.full(
-                [batch_size, 1], forced_bos_token_id, dtype="int32")
+            if decode_strategy == "sampling":
+                trg_word = paddle.full(
+                    [batch_size * num_return_sequences, 1],
+                    forced_bos_token_id,
+                    dtype="int32")
+            else:
+                trg_word = paddle.full(
+                    [batch_size, 1], forced_bos_token_id, dtype="int32")
         else:
             trg_word = paddle.zeros([0])
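The branch added above presumably exists because, with sampling, the batch has already been expanded by `num_return_sequences` before decoding, while the other strategies still have `batch_size` rows at this point. A small illustrative sketch of the resulting shapes (values are made up):

```python
# Sketch of the shape logic introduced above, with illustrative values.
import paddle

batch_size, num_return_sequences, forced_bos_token_id = 4, 2, 0

# sampling: the batch was already expanded, so the forced-BOS column must match
trg_word_sampling = paddle.full(
    [batch_size * num_return_sequences, 1], forced_bos_token_id, dtype="int32")
print(trg_word_sampling.shape)  # [8, 1]

# beam search / greedy: one row per original sample
trg_word_other = paddle.full([batch_size, 1], forced_bos_token_id, dtype="int32")
print(trg_word_other.shape)  # [4, 1]
```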
