
Commit 4c36ef9

FasterUnifiedTransformer/PLATO support dy2sta (#1717)
* Support UnifiedTransformer (UT) dygraph-to-static (dy2sta) export
* Use JIT load for the FasterTransformer custom op
1 parent 4a91065 commit 4c36ef9
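
"dy2sta" is PaddlePaddle's dygraph-to-static conversion. As a rough orientation for the two new files below, a minimal sketch of the general pattern they rely on (a generic illustration, not the commit's own code; `net` is any assumed paddle.nn.Layer) looks like this:

import paddle

# Assume `net` is a paddle.nn.Layer built in eager (dygraph) mode.
net.eval()
static_net = paddle.jit.to_static(
    net,
    input_spec=[paddle.static.InputSpec(shape=[None, None], dtype="int32")])
# paddle.jit.save writes <prefix>.pdmodel (program) and <prefix>.pdiparams (weights).
paddle.jit.save(static_net, "./infer_model/model")

The saved program and parameters can then be loaded with paddle.inference, which is what the second new file in this commit does.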

File tree

6 files changed: +372 −63 lines changed

New file: 159 additions & 0 deletions — PLATO/FasterUnifiedTransformer dygraph-to-static export script
@@ -0,0 +1,159 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from pprint import pprint

import paddle

from paddlenlp.transformers import UnifiedTransformerLMHeadModel, UnifiedTransformerTokenizer
from paddlenlp.ops import FasterUnifiedTransformer
from paddlenlp.utils.log import logger


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default="plato-xl",
        type=str,
        help="The name of the UnifiedTransformer/PLATO model to export. ")
    parser.add_argument(
        "--inference_model_dir",
        default="./infer_model/",
        type=str,
        help="Path to save the exported inference model. ")
    parser.add_argument(
        "--topk",
        default=4,
        type=int,
        help="The number of candidates for top-k sampling. ")
    parser.add_argument(
        "--topp",
        default=1.0,
        type=float,
        help="The probability threshold for top-p sampling. ")
    parser.add_argument(
        "--max_out_len", default=64, type=int, help="Maximum output length. ")
    parser.add_argument(
        "--min_out_len", default=1, type=int, help="Minimum output length. ")
    parser.add_argument(
        "--temperature",
        default=1.0,
        type=float,
        help="The temperature to set. ")
    parser.add_argument(
        "--num_return_sequences",
        default=1,
        type=int,
        help="The number of returned sequences. ")
    parser.add_argument(
        "--use_fp16_decoding",
        action="store_true",
        help="Whether to use fp16 decoding to predict. ")
    parser.add_argument(
        "--decoding_strategy",
        default="sampling",
        choices=["sampling", "beam_search"],
        type=str,
        help="The main strategy to decode. ")
    parser.add_argument(
        "--num_beams",
        default=4,
        type=int,
        help="The number of candidates for beam search. ")
    parser.add_argument(
        "--diversity_rate",
        default=0.0,
        type=float,
        help="The diversity rate for beam search. ")

    args = parser.parse_args()
    return args


def do_predict(args):
    place = "gpu"
    place = paddle.set_device(place)

    model_name = 'plato-xl'
    model = UnifiedTransformerLMHeadModel.from_pretrained(model_name)
    tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name)

    plato = FasterUnifiedTransformer(
        model=model, use_fp16_decoding=args.use_fp16_decoding)
    # Set evaluate mode
    plato.eval()

    # Convert dygraph model to static graph model. Tensor inputs get
    # InputSpec placeholders; generation hyperparameters are passed as
    # Python scalars and baked into the traced program.
    plato = paddle.jit.to_static(
        plato,
        input_spec=[
            # input_ids
            paddle.static.InputSpec(
                shape=[None, None], dtype="int32"),
            # token_type_ids
            paddle.static.InputSpec(
                shape=[None, None], dtype="int32"),
            # attention_mask
            paddle.static.InputSpec(
                shape=[None, 1, None, None], dtype="float32"),
            # seq_len
            paddle.static.InputSpec(
                shape=[None], dtype="int32"),
            # role_ids
            paddle.static.InputSpec(
                shape=[None, None], dtype="int32"),
            # position_ids
            paddle.static.InputSpec(
                shape=[None, None], dtype="int32"),
            args.max_out_len,
            args.min_out_len,
            args.topk,
            args.topp,
            args.decoding_strategy,
            tokenizer.cls_token_id,  # cls/bos
            tokenizer.eos_token_id,  # eos
            tokenizer.pad_token_id,  # pad
            args.num_beams,  # num_beams. Used for beam_search.
            args.diversity_rate,  # diversity rate. Used for beam search.
            args.temperature,
            args.num_return_sequences,
        ])

    # Save converted static graph model
    paddle.jit.save(plato, os.path.join(args.inference_model_dir, "plato"))
    logger.info("PLATO has been saved to {}".format(args.inference_model_dir))


if __name__ == "__main__":
    args = parse_args()
    pprint(args)

    do_predict(args)
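
Once the export script above has run, paddle.jit.save should leave the static graph artifacts under --inference_model_dir; the inference script below expects plato.pdmodel and plato.pdiparams there. A minimal sketch for verifying the output (the ./infer_model/ path is just the script's default; nothing here is part of the commit):

import os

infer_dir = "./infer_model/"
for fname in ("plato.pdmodel", "plato.pdiparams"):
    path = os.path.join(infer_dir, fname)
    # Both files are written by paddle.jit.save(plato, .../"plato").
    print(path, "exists" if os.path.exists(path) else "missing")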
New file: 119 additions & 0 deletions — PLATO static graph inference script
@@ -0,0 +1,119 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import numpy as np
from pprint import pprint

import paddle
import paddle.inference as paddle_infer

from paddlenlp.transformers import UnifiedTransformerTokenizer
from paddlenlp.ops.ext_utils import load


def setup_args():
    """Setup arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--inference_model_dir",
        default="./infer_model/",
        type=str,
        help="Path to the exported PLATO inference model. ")
    parser.add_argument(
        "--use_role",
        action="store_true",
        help="Whether to use role embeddings. ")
    parser.add_argument(
        "--position_style",
        default="relative",
        choices=["continuous", "relative"],
        type=str,
        help="The type of positional embedding. Default is relative. ")

    args = parser.parse_args()

    return args


def postprocess_response(token_ids, tokenizer):
    """Post-process the decoded sequence. Truncate from the first <eos>."""
    eos_pos = len(token_ids)
    for i, tok_id in enumerate(token_ids):
        if tok_id == tokenizer.sep_token_id:
            eos_pos = i
            break
    token_ids = token_ids[:eos_pos]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    tokens = tokenizer.merge_subword(tokens)
    return tokens


def infer(args):
    model_name = 'plato-xl'
    tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name)

    context = [
        "Hi , Becky , what's up ?",
        "Not much , except that my mother-in-law is driving me up the wall .",
        "What's the problem ?"
    ]

    data = tokenizer.dialogue_encode(
        history=context,
        add_start_token_as_response=True,
        return_length=True,
        return_role_ids=args.use_role,
        position_style=args.position_style)

    # JIT compile and load the FasterTransformer custom op library.
    load("FasterTransformer", verbose=True)

    config = paddle_infer.Config(args.inference_model_dir + "plato.pdmodel",
                                 args.inference_model_dir + "plato.pdiparams")
    config.enable_use_gpu(100, 0)
    config.disable_glog_info()
    predictor = paddle_infer.create_predictor(config)

    input_handles = {}
    for name in predictor.get_input_names():
        input_handles[name] = predictor.get_input_handle(name)
        if name == "attention_mask":
            # The square mask shape is hard-coded to match the encoded
            # length of the example context above (41 tokens).
            input_handles[name].copy_from_cpu(
                np.asarray(
                    data[name], dtype="float32").reshape([1, 1, 41, 41]))
        else:
            input_handles[name].copy_from_cpu(
                np.asarray(
                    data[name], dtype="int32").reshape([1, -1]))

    output_handles = [
        predictor.get_output_handle(name)
        for name in predictor.get_output_names()
    ]

    predictor.run()

    output = [output_handle.copy_to_cpu() for output_handle in output_handles]

    for sample in output[0].transpose([1, 0]).tolist():
        print(" ".join(postprocess_response(sample, tokenizer)))


if __name__ == "__main__":
    args = setup_args()
    pprint(args)

    infer(args)
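
The attention_mask reshape in the script above is hard-coded to the encoded length of the example context. To reuse the script with a different context, a sketch of deriving the shape from the tokenizer output could look like the following (it assumes dialogue_encode returns a flat square [seq_len, seq_len] mask, which is not stated in the commit itself):

import numpy as np

mask = np.asarray(data["attention_mask"], dtype="float32")
seq_len = int(np.sqrt(mask.size))  # assumed square mask over the encoded sequence
mask = mask.reshape([1, 1, seq_len, seq_len])
input_handles["attention_mask"].copy_from_cpu(mask)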
