Commit dbd26fe

Add PLATO-XL example (#1708)
1 parent 4c36ef9 commit dbd26fe

File tree

4 files changed: +338 −26 lines changed

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# PLATO-XL

## Model Introduction

Building a high-quality open-domain dialogue bot that can converse freely with humans in natural language has long been one of the ultimate goals of natural language processing.

To make building a high-quality open-domain chatbot easy, this project implements the PLATO-XL inference model on Paddle together with high-performance inference acceleration. The end-to-end float16 scheme makes it possible to load and run the 11B-parameter PLATO-XL model on a single 32G V100, with no float32 computation involved.

Note that PLATO-XL is a large model (72 layers, 32 heads, hidden size 3072): even with float16, the 72-layer network needs roughly 24G of GPU memory, and the GPU you use must support float16 computation.

Specifically:

* The GPUs that support float16 are listed on the NVIDIA [website](https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix);
* The compute capability of your GPU can likewise be found on the NVIDIA [website](https://developer.nvidia.com/zh-cn/cuda-gpus#compute) and maps to the table linked above (a quick programmatic check is sketched below).
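The compute capability can also be queried from Python. This is a minimal sketch, assuming a GPU build of Paddle that provides `paddle.device.cuda.get_device_capability`; the 5.3 threshold for native float16 support is an assumption based on NVIDIA's documentation and should be verified against the tables linked above:

```python
import paddle

# Query the compute capability (major, minor) of the current CUDA device.
major, minor = paddle.device.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")

# Assumption: native float16 arithmetic generally requires compute capability >= 5.3;
# confirm against the NVIDIA hardware-precision matrix linked above.
print("float16 likely supported:", (major, minor) >= (5, 3))
```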
For PLATO-XL's training process and other details, see [Knover](https://github.com/PaddlePaddle/Knover/tree/develop/projects/PLATO-XL).

## Quick Start

### Requirements

- python 3.7+
- sentencepiece

Installation:

```shell
pip install sentencepiece
```
### High-Performance Generation

Use the `infer.py` script for testing. There is no need to download the pretrained model separately; the script downloads it automatically. Run the command below for high-performance inference; the forward pass is looped 200 times for benchmarking purposes.

```shell
export CUDA_VISIBLE_DEVICES=0
python infer.py --use_role --position_style relative --max_out_len 64 --min_out_len 1 --topk 4
```

The script's arguments are as follows:

* `--use_role`: whether to use role embeddings.
* `--position_style`: the position-encoding style, either "relative" or "continuous".
* `--max_out_len`: the maximum output length.
* `--min_out_len`: the minimum output length.
* `--topk`: the k value for top-k sampling.
* `--topp`: the p value for top-p sampling.
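In addition to the flags listed above, the `infer.py` added in this commit (see below) also accepts `--use_fp16_decoding`, `--decoding_strategy`, and `--num_beams`. As an illustrative invocation (not part of the original README), a float16 sampling run could look like:

```shell
export CUDA_VISIBLE_DEVICES=0
python infer.py --use_role --position_style relative --max_out_len 64 --min_out_len 1 \
    --topk 4 --decoding_strategy sampling --use_fp16_decoding
```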
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import argparse
from pprint import pprint

import paddle

from paddlenlp.transformers import UnifiedTransformerModel, UnifiedTransformerLMHeadModel, UnifiedTransformerTokenizer


def setup_args():
    """Setup arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_role",
        action="store_true",
        help="Whether to use role embeddings. ")
    parser.add_argument(
        "--position_style",
        default="relative",
        choices=["continuous", "relative"],
        type=str,
        help="The type of positional embedding. Default is relative. ")
    parser.add_argument(
        "--max_out_len",
        default=64,
        type=int,
        help="Maximum output sequence length. ")
    parser.add_argument(
        "--min_out_len",
        default=1,
        type=int,
        help="Minimum output sequence length. ")
    parser.add_argument(
        "--topk",
        default=4,
        type=int,
        help="The k value for topk_sampling. Default is 4. ")
    parser.add_argument(
        "--topp",
        default=1.0,
        type=float,
        help="The p value for topp_sampling. Default is 1.0. ")
    parser.add_argument(
        "--use_fp16_decoding",
        action="store_true",
        help="Whether to use fp16 decoding to predict. ")
    parser.add_argument(
        "--decoding_strategy",
        default="sampling",
        choices=["sampling", "beam_search"],
        type=str,
        help="The main strategy to decode. ")
    parser.add_argument(
        "--num_beams",
        default=4,
        type=int,
        help="The number of candidates used in beam search. ")

    args = parser.parse_args()

    return args


def postprocess_response(token_ids, tokenizer):
    """Post-process the decoded sequence. Truncate from the first <eos>."""
    eos_pos = len(token_ids)
    for i, tok_id in enumerate(token_ids):
        if tok_id == tokenizer.sep_token_id:
            eos_pos = i
            break
    token_ids = token_ids[:eos_pos]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    tokens = tokenizer.merge_subword(tokens)
    return tokens


def infer(args):
    model_name = 'plato-xl'
    model = UnifiedTransformerLMHeadModel.from_pretrained(model_name)
    tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name)

    context = [
        "Hi , Becky , what's up ?",
        "Not much , except that my mother-in-law is driving me up the wall .",
        "What's the problem ?"
    ]

    data = tokenizer.dialogue_encode(
        history=context,
        add_start_token_as_response=True,
        return_length=True,
        return_role_ids=args.use_role,
        position_style=args.position_style)

    for name in data:
        if name == "attention_mask":
            # The fixed context above encodes to 41 tokens, hence the 41 x 41 mask.
            data[name] = paddle.to_tensor(
                data[name], dtype="float32").reshape([1, 1, 41, 41])
        else:
            data[name] = paddle.to_tensor(
                data[name], dtype="int32").reshape([1, -1])

    # 200 iterations in total: the first 100 warm up, the last 100 are timed.
    for i in range(200):
        if 100 == i:
            paddle.device.cuda.synchronize()
            start = time.time()

        outputs, _ = model.generate(
            input_ids=data['input_ids'],
            token_type_ids=data['token_type_ids'],
            position_ids=data['position_ids'],
            attention_mask=data['attention_mask'],
            role_ids=data.get('role_ids', None),
            seq_len=data['seq_len'],
            max_length=args.max_out_len,
            min_length=args.min_out_len,
            decode_strategy=args.decoding_strategy,
            top_k=args.topk,
            top_p=args.topp,
            num_beams=args.num_beams,
            use_fp16_decoding=args.use_fp16_decoding,
            use_faster=True)

    paddle.device.cuda.synchronize()
    print("Average time of FasterGeneration of PLATO-XL model is {}ms. ".format(
        (time.time() - start) / 100 * 1000))

    result = postprocess_response(outputs[0].numpy(), tokenizer)
    result = " ".join(result)

    print("Model input:", context)
    print("Result:", result)


if __name__ == "__main__":
    args = setup_args()
    pprint(args)

    infer(args)
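One detail worth noting in the script above: the attention mask is reshaped to the hardcoded `[1, 1, 41, 41]`, which matches the token count of the fixed example context. A hedged sketch of how the shape could instead be derived from the encoded length, computed before any feature is converted to a tensor (this is not part of the committed script and assumes `dialogue_encode` returns plain Python lists):

```python
# Sketch only (not in the commit): derive the mask shape from the encoded length so
# that contexts other than the hardcoded example also work. Run this before the
# to_tensor conversion loop above.
seq_len = len(data["input_ids"])  # 41 for the hardcoded example context
data["attention_mask"] = paddle.to_tensor(
    data["attention_mask"], dtype="float32").reshape([1, 1, seq_len, seq_len])
```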

paddlenlp/transformers/unified_transformer/modeling.py

Lines changed: 54 additions & 23 deletions
@@ -96,6 +96,26 @@ class UnifiedTransformerPretrainedModel(PretrainedModel):
             "eos_token_id": 2,
             "mask_token_id": 30000,
         },
+        "plato-xl": {
+            "vocab_size": 8001,
+            "hidden_size": 3072,
+            "num_hidden_layers": 72,
+            "num_attention_heads": 32,
+            "intermediate_size": 18432,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "attention_probs_dropout_prob": 0.1,
+            "normalize_before": True,
+            "max_position_embeddings": 1024,
+            "type_vocab_size": 2,
+            "role_type_size": 128,
+            "initializer_range": 0.02,
+            "unk_token_id": 0,
+            "pad_token_id": 0,
+            "bos_token_id": 1,
+            "eos_token_id": 2,
+            "mask_token_id": 8000,
+        }
     }
     resource_files_names = {"model_state": "model_state.pdparams"}
     pretrained_resource_files_map = {
@@ -106,6 +126,8 @@ class UnifiedTransformerPretrainedModel(PretrainedModel):
             "https://bj.bcebos.com/paddlenlp/models/transformers/unified_transformer/unified_transformer-12L-cn-luge.pdparams",
             "plato-mini":
             "https://bj.bcebos.com/paddlenlp/models/transformers/unified_transformer/plato-mini.pdparams",
+            "plato-xl":
+            "https://bj.bcebos.com/paddlenlp/models/transformers/unified_transformer/plato-xl.pdparams",
         }
     }
     base_model_prefix = "unified_transformer"
@@ -115,7 +137,9 @@ def init_weights(self, layer):
         if isinstance(layer, (nn.Linear, nn.Embedding)):
             # In the dygraph mode, use the `set_value` to reset the parameter directly,
             # and reset the `state_dict` to update parameter in static mode.
-            if isinstance(layer.weight, paddle.Tensor):
+            if isinstance(
+                    layer.weight,
+                    paddle.Tensor) and paddle.get_default_dtype() == "float32":
                 layer.weight.set_value(
                     paddle.tensor.normal(
                         mean=0.0,
@@ -133,20 +157,27 @@ def __init__(self,
                  hidden_size=768,
                  hidden_dropout_prob=0.1,
                  max_position_embeddings=512,
-                 type_vocab_size=2):
+                 type_vocab_size=2,
+                 role_type_size=None):
         super(UnifiedTransformerEmbeddings, self).__init__()
         self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
         self.position_embeddings = nn.Embedding(max_position_embeddings,
                                                 hidden_size)
         self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
+        self.role_embeddings = None if role_type_size is None else nn.Embedding(
+            role_type_size, hidden_size)
         self.dropout = nn.Dropout(hidden_dropout_prob)
 
-    def forward(self, input_ids, token_type_ids, position_ids):
+    def forward(self, input_ids, token_type_ids, position_ids, role_ids=None):
         input_embedings = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
         embeddings = input_embedings + position_embeddings + token_type_embeddings
+
+        if self.role_embeddings is not None:
+            embeddings += self.role_embeddings(role_ids)
+
         embeddings = self.dropout(embeddings)
         return embeddings
 
@@ -221,25 +252,25 @@ class UnifiedTransformerModel(UnifiedTransformerPretrainedModel):
             The id of special token `mask_token`. Defaults to 30000.
     """
 
-    def __init__(
-            self,
-            vocab_size,
-            hidden_size=768,
-            num_hidden_layers=12,
-            num_attention_heads=12,
-            intermediate_size=3072,
-            hidden_act="gelu",
-            hidden_dropout_prob=0.1,
-            attention_probs_dropout_prob=0.1,
-            normalize_before=True,
-            max_position_embeddings=512,
-            type_vocab_size=2,
-            initializer_range=0.02,
-            unk_token_id=0,
-            pad_token_id=0,
-            bos_token_id=1,
-            eos_token_id=2,
-            mask_token_id=30000, ):
+    def __init__(self,
+                 vocab_size,
+                 hidden_size=768,
+                 num_hidden_layers=12,
+                 num_attention_heads=12,
+                 intermediate_size=3072,
+                 hidden_act="gelu",
+                 hidden_dropout_prob=0.1,
+                 attention_probs_dropout_prob=0.1,
+                 normalize_before=True,
+                 max_position_embeddings=512,
+                 type_vocab_size=2,
+                 initializer_range=0.02,
+                 unk_token_id=0,
+                 pad_token_id=0,
+                 bos_token_id=1,
+                 eos_token_id=2,
+                 mask_token_id=30000,
+                 role_type_size=None):
         super(UnifiedTransformerModel, self).__init__()
         self.unk_token_id = unk_token_id
         self.pad_token_id = pad_token_id
@@ -250,7 +281,7 @@ def __init__(
 
         self.embeddings = UnifiedTransformerEmbeddings(
             vocab_size, hidden_size, hidden_dropout_prob,
-            max_position_embeddings, type_vocab_size)
+            max_position_embeddings, type_vocab_size, role_type_size)
         encoder_layer = nn.TransformerEncoderLayer(
             hidden_size,
             num_attention_heads,
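With the configuration entry and weight URL registered above, `plato-xl` becomes loadable through the usual `from_pretrained` path. A minimal sketch, assuming a PaddleNLP build that contains this commit; setting the default dtype to float16 is optional and relies on the `init_weights` change above, which skips random re-initialization whenever the default dtype is not float32:

```python
import paddle
from paddlenlp.transformers import (UnifiedTransformerLMHeadModel,
                                    UnifiedTransformerTokenizer)

# Optional: keep parameters in float16, matching the single-32G-V100 setup in the README.
paddle.set_default_dtype("float16")

tokenizer = UnifiedTransformerTokenizer.from_pretrained("plato-xl")
model = UnifiedTransformerLMHeadModel.from_pretrained("plato-xl")
model.eval()
```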
