Commit f21829c

Paddle2ONNX deploy script (#2149)
* add paddle2onnx deploy script
* update doc
* update doc
* simplify
* update doc
* revert doc
* update doc
* update paddle2onnx doc
* update outer doc
* update paddle2onnx doc
* remove requirments in doc
* update paddle2onnx doc
* update
* paddle2onnx doc
1 parent ba57969 commit f21829c

5 files changed: +367 -2 lines changed


model_zoo/ernie-3.0/README.md

Lines changed: 14 additions & 0 deletions
@@ -17,6 +17,8 @@
 * [Python Deployment Guide](#Python部署指南)
 * [Serving Deployment](#服务化部署)
 * [Dependencies](#环境依赖)
+* [Paddle2ONNX Deployment](#Paddle2ONNX部署)
+* [ONNX Export and ONNXRuntime Deployment](#ONNX导出及ONNXRuntime部署)
 
 
 
@@ -415,6 +417,9 @@ TBD
 │   └── token_cls_rpc_client.py
 │   └── token_cls_service.py
 │   └── token_cls_config.yml
+│ └── paddle2onnx
+│   └── ernie_predictor.py
+│   └── infer.py
 └── README.md # Documentation (this file)
 
 ```
@@ -599,6 +604,10 @@ TBD
 <a name="部署"></a>
 
 ## Deployment
+We provide several deployment options for ERNIE 3.0 that cover different deployment scenarios. Choose the one that fits your use case.
+<p align="center">
+<img width="700" alt="image" src="https://user-images.githubusercontent.com/30516196/168466069-e8162235-2f06-4a2d-b78f-d9afd437c620.png">
+</p>
 
 <a name="Python部署"></a>
 
@@ -613,7 +622,12 @@ For Python deployment, see the [Python Deployment Guide](./deploy/python/README.md)
 ### Serving Deployment
 TBD
 
+<a name="Paddle2ONNX部署"></a>
+
+### Paddle2ONNX Deployment
 
+<a name="ONNX导出及ONNXRuntime部署"></a>
+For ONNX export and ONNXRuntime deployment, see the [ONNX Export and ONNXRuntime Deployment Guide](./deploy/paddle2onnx/README.md)
 ## Reference
 
 * Sun Y, Wang S, Feng S, et al. ERNIE 3.0: Large-scale Knowledge Enhanced Pre-training for Language Understanding and Generation[J]. arXiv preprint arXiv:2107.02137, 2021.
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+# ERNIE 3.0 ONNX Export and Deployment Guide
+This guide describes how to convert an ERNIE 3.0 model to the ONNX format and deploy it with the ONNXRuntime engine, using named entity recognition and text classification as the two example scenarios.
+- [ERNIE 3.0 ONNX Export and Deployment Guide](#ernie-30-onnx-export-and-deployment-guide)
+  - [1. Environment Setup](#1-environment-setup)
+  - [2. Named Entity Recognition Model Inference](#2-named-entity-recognition-model-inference)
+    - [2.1 Model Download](#21-model-download)
+    - [2.2 Model Conversion](#22-model-conversion)
+    - [2.3 ONNXRuntime Inference Example](#23-onnxruntime-inference-example)
+  - [3. Classification Model Inference](#3-classification-model-inference)
+    - [3.1 Model Download](#31-model-download)
+    - [3.2 Model Conversion](#32-model-conversion)
+    - [3.3 ONNXRuntime Inference Example](#33-onnxruntime-inference-example)
+## 1. Environment Setup
+ERNIE 3.0 model conversion and ONNXRuntime inference depend on Paddle2ONNX and ONNXRuntime. Paddle2ONNX converts Paddle models to the ONNX format and currently provides stable operator export for ONNX Opset 7 to 15. For details, see [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX)
+For CPU deployment, install the required dependency with the following command:
+```
+python -m pip install onnxruntime
+```
+For GPU deployment, first make sure the NVIDIA driver and base software are installed correctly, with CUDA >= 11.2 and cuDNN >= 8.2, then install the required dependency with the following command:
+```
+python -m pip install onnxruntime-gpu
+```
+
+## 2. Named Entity Recognition Model Inference
+### 2.1 Model Download
+You can run inference with a model you trained yourself (see [Model Training and Tuning](./../../README.md#微调) for the fine-tuning procedure), or use the ERNIE 3.0 model we provide, which was trained on the msra_ner dataset. Download it with the following commands:
+```
+# Download the FP32 named entity recognition model
+wget https://paddlenlp.bj.bcebos.com/models/transformers/ernie_3.0/msra_ner_pruned_infer_model.zip
+unzip msra_ner_pruned_infer_model.zip
+```
+### 2.2 Model Conversion
+The following command uses Paddle2ONNX to convert the Paddle static-graph model to the ONNX format. After it completes successfully, the model file ner_model.onnx is generated in the current directory.
+```
+paddle2onnx --model_dir msra_ner_pruned_infer_model/ --model_filename float32.pdmodel --params_filename float32.pdiparams --save_file ner_model.onnx --opset_version 13 --enable_onnx_checker True
+```
+For the Paddle2ONNX command-line arguments, see: [Paddle2ONNX command-line arguments](https://github.com/PaddlePaddle/Paddle2ONNX)
+
+### 2.3 ONNXRuntime Inference Example
+Run inference with the following command:
+```
+python infer.py --task_name token_cls --model_path ner_model.onnx
+```
+The output looks like this:
+```
+input data: 北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。
+The model detects all entities:
+entity: 北京 label: LOC pos: [0, 1]
+entity: 重庆 label: LOC pos: [6, 7]
+entity: 成都 label: LOC pos: [12, 13]
+-----------------------------
+input data: 乔丹、科比、詹姆斯和姚明都是篮球界的标志性人物。
+The model detects all entities:
+entity: 乔丹 label: PER pos: [0, 1]
+entity: 科比 label: PER pos: [3, 4]
+entity: 詹姆斯 label: PER pos: [6, 8]
+entity: 姚明 label: PER pos: [10, 11]
+-----------------------------
+```
+Arguments of the infer.py script:
+| Argument | Description |
+|----------|--------------|
+|--task_name | Task name, either seq_cls or token_cls; defaults to seq_cls |
+|--model_name_or_path | Path or name of the model (used to load the tokenizer); defaults to ernie-3.0-medium-zh |
+|--model_path | Path of the ONNX model used for inference |
+|--max_seq_length | Maximum sequence length; defaults to 128 |
+
+## 3. Classification Model Inference
+### 3.1 Model Download
+You can run inference with a model you trained yourself (see [Model Training and Tuning](./../../README.md#微调) for the fine-tuning procedure), or use the ERNIE 3.0 model we provide, which was trained on the tnews dataset. Download it with the following commands:
+```
+# Download the classification model
+wget https://paddlenlp.bj.bcebos.com/models/transformers/ernie_3.0/tnews_pruned_infer_model.zip
+unzip tnews_pruned_infer_model.zip
+```
+### 3.2 Model Conversion
+The following command uses Paddle2ONNX to convert the Paddle static-graph model to the ONNX format. After it completes successfully, the model file tnews_model.onnx is generated in the current directory.
+```
+paddle2onnx --model_dir tnews_pruned_infer_model/ --model_filename float32.pdmodel --params_filename float32.pdiparams --save_file tnews_model.onnx --opset_version 13 --enable_onnx_checker True
+```
+For the Paddle2ONNX command-line arguments, see: [Paddle2ONNX command-line arguments](https://github.com/PaddlePaddle/Paddle2ONNX)
+
+### 3.3 ONNXRuntime Inference Example
+Run inference with the following command:
+```
+python infer.py --task_name seq_cls --model_path tnews_model.onnx
+```
+The output looks like this:
+```
+input data: 未来自动驾驶真的会让酒驾和疲劳驾驶成历史吗?
+seq cls result:
+label: news_car confidence: 0.554353654384613
+-----------------------------
+input data: 黄磊接受华少快问快答,不光智商逆天,情商也不逊黄渤
+seq cls result:
+label: news_entertainment confidence: 0.9495906829833984
+-----------------------------
+```
+Arguments of the infer.py script:
+| Argument | Description |
+|----------|--------------|
+|--task_name | Task name, either seq_cls or token_cls; defaults to seq_cls |
+|--model_name_or_path | Path or name of the model (used to load the tokenizer); defaults to ernie-3.0-medium-zh |
+|--model_path | Path of the ONNX model used for inference |
+|--max_seq_length | Maximum sequence length; defaults to 128 |
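The deployment guide above distinguishes a CPU install (`onnxruntime`) from a GPU install (`onnxruntime-gpu`). As a minimal sketch of how a script could pick execution providers for either install, the helper below orders the providers so that CUDA is tried first and CPU is always kept as a fallback. The function name `pick_providers` is hypothetical and not part of this commit; the provider names are the standard ONNXRuntime identifiers.

```python
def pick_providers(available):
    """Return execution providers in preference order: CUDA first, CPU as fallback."""
    preferred = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    chosen = [name for name in preferred if name in available]
    # Always return at least the CPU provider so session creation never
    # depends on a GPU being present.
    return chosen or ["CPUExecutionProvider"]


if __name__ == "__main__":
    try:
        import onnxruntime as ort  # provided by onnxruntime or onnxruntime-gpu
        available = ort.get_available_providers()
    except ImportError:
        available = []  # assumption for this sketch: ONNXRuntime is not installed
    print(pick_providers(available))
```

The returned list could then be passed as the `providers` argument of `ort.InferenceSession`, which is how the predictor in this commit configures its session.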
Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+import os
+import numpy as np
+import paddle
+import onnxruntime as ort
+from paddlenlp.transformers import AutoTokenizer
+
+
+class InferBackend(object):
+    def __init__(self, model_path):
+        print(">>> [InferBackend] Creating Engine ...")
+        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']  # GPU first, CPU fallback
+        sess_options = ort.SessionOptions()
+        self.predictor = ort.InferenceSession(
+            model_path, sess_options=sess_options, providers=providers)
+        if "CUDAExecutionProvider" in self.predictor.get_providers():
+            print(">>> [InferBackend] Use GPU for inference ...")
+        else:
+            print(">>> [InferBackend] Use CPU for inference ...")
+        input_name1 = self.predictor.get_inputs()[1].name
+        input_name2 = self.predictor.get_inputs()[0].name
+        self.input_handles = [input_name1, input_name2]
+        print(">>> [InferBackend] Engine Created ...")
+
+    def infer(self, input_dict: dict):
+        result = self.predictor.run(None, input_dict)
+        return result
+
+
+def token_cls_print_ret(infer_result, input_datas):
+    rets = infer_result["value"]
+    for i, ret in enumerate(rets):
+        print("input data:", input_datas[i])
+        print("The model detects all entities:")
+        for item in ret:
+            print("entity:", item["entity"], " label:", item["label"],
+                  " pos:", item["pos"])
+        print("-----------------------------")
+
+
+def seq_cls_print_ret(infer_result, input_datas):
+    label_list = [
+        "news_story", "news_culture", "news_entertainment", "news_sports",
+        "news_finance", "news_house", "news_car", "news_edu", "news_tech",
+        "news_military", "news_travel", "news_world", "news_stock",
+        "news_agriculture", "news_game"
+    ]
+    label = infer_result["label"].reshape(-1).tolist()  # stays a list for a batch of one
+    confidence = infer_result["confidence"].reshape(-1).tolist()
+    for i, text in enumerate(input_datas):  # iterate over the inputs, not the result dict's keys
+        print("input data:", text)
+        print("seq cls result:")
+        print("label:", label_list[label[i]], " confidence:", confidence[i])
+        print("-----------------------------")
+
+
+class ErniePredictor(object):
+    def __init__(self, args):
+        self.task_name = args.task_name
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+        if args.task_name == 'seq_cls':
+            self.label_names = []
+            self.preprocess = self.seq_cls_preprocess
+            self.postprocess = self.seq_cls_postprocess
+            self.printer = seq_cls_print_ret
+        elif args.task_name == 'token_cls':
+            self.label_names = [
+                'O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'
+            ]
+            self.preprocess = self.token_cls_preprocess
+            self.postprocess = self.token_cls_postprocess
+            self.printer = token_cls_print_ret
+        else:
+            # Fail with an error instead of exiting with a success status code.
+            raise ValueError(
+                "[ErniePredictor]: task_name only supports seq_cls and token_cls."
+            )
+
+        self.max_seq_length = args.max_seq_length
+        self.inference_backend = InferBackend(args.model_path)
+
+    def seq_cls_preprocess(self, input_data: list):
+        data = input_data
+        # tokenizer + pad
+        data = self.tokenizer(
+            data, max_length=self.max_seq_length, padding=True, truncation=True)
+        input_ids = data["input_ids"]
+        token_type_ids = data["token_type_ids"]
+        return {
+            "input_ids": np.array(
+                input_ids, dtype="int64"),
+            "token_type_ids": np.array(
+                token_type_ids, dtype="int64")
+        }
+
+    def seq_cls_postprocess(self, infer_data, input_data):
+        logits = np.array(infer_data[0])
+        max_value = np.max(logits, axis=1, keepdims=True)  # row max, for a numerically stable softmax
+        exp_data = np.exp(logits - max_value)
+        probs = exp_data / np.sum(exp_data, axis=1, keepdims=True)
+        out_dict = {
+            "label": probs.argmax(axis=-1),
+            "confidence": probs.max(axis=-1)
+        }
+        return out_dict
+
+    def token_cls_preprocess(self, data: list):
+        # tokenizer + pad
+        is_split_into_words = False
+        if isinstance(data[0], list):
+            is_split_into_words = True
+        data = self.tokenizer(
+            data,
+            max_length=self.max_seq_length,
+            padding=True,
+            truncation=True,
+            is_split_into_words=is_split_into_words)
+
+        input_ids = data["input_ids"]
+        token_type_ids = data["token_type_ids"]
+        return {
+            "input_ids": np.array(
+                input_ids, dtype="int64"),
+            "token_type_ids": np.array(
+                token_type_ids, dtype="int64")
+        }
+
+    def token_cls_postprocess(self, infer_data, input_data):
+        result = np.array(infer_data[0])
+        tokens_label = result.argmax(axis=-1).tolist()
+        # Collect the entities of each sequence in the batch; token i maps to character i - 1 because of [CLS]
+        value = []
+        for batch, token_label in enumerate(tokens_label):
+            start = -1
+            label_name = ""
+            items = []
+            for i, label in enumerate(token_label):
+                if self.label_names[label] == "O" and start >= 0:
+                    entity = input_data[batch][start:i - 1]
+                    if isinstance(entity, list):
+                        entity = "".join(entity)
+                    items.append({
+                        "pos": [start, i - 2],
+                        "entity": entity,
+                        "label": label_name,
+                    })
+                    start = -1
+                elif "B-" in self.label_names[label]:
+                    start = i - 1
+                    label_name = self.label_names[label][2:]
+            if start >= 0:
+                items.append({
+                    "pos": [start, len(token_label) - 2],  # inclusive end, as in the branch above
+                    "entity": input_data[batch][start:len(token_label) - 1],
+                    "label": label_name  # keep the label of the still-open entity
+                })
+            value.append(items)
+
+        out_dict = {"value": value, "tokens_label": tokens_label}
+        return out_dict
+
+    def infer(self, data):
+        return self.inference_backend.infer(data)
+
+    def predict(self, input_data: list):
+        preprocess_result = self.preprocess(input_data)
+        infer_result = self.infer(preprocess_result)
+        result = self.postprocess(infer_result, input_data)
+        self.printer(result, input_data)
+        return result
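The entity extraction in `token_cls_postprocess` walks the per-token label ids and converts B-/I-/O tags into character spans. The standalone sketch below illustrates that decoding step in plain Python, without NumPy or a tokenizer. It is a variant, not the commit's exact logic: the name `decode_entities` is hypothetical, it assumes the first label belongs to `[CLS]` and that every character is one token, and unlike the method above it also closes an entity when a new `B-` tag follows directly.

```python
LABEL_NAMES = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]


def decode_entities(text, label_ids, label_names=LABEL_NAMES):
    """Turn per-token label ids into entity spans.

    Assumes label_ids[0] belongs to [CLS], so token i maps to character i - 1,
    and that each character of `text` is exactly one token.
    """
    entities = []
    start, name = -1, ""

    def close(end):
        # `end` is the exclusive character index of the open entity.
        entities.append({
            "entity": text[start:end],
            "label": name,
            "pos": [start, end - 1],
        })

    for i, label_id in enumerate(label_ids):
        label = label_names[label_id]
        if label.startswith("B-"):
            if start >= 0:
                close(i - 1)  # a new entity starts right after the previous one
            start, name = i - 1, label[2:]
        elif label == "O" and start >= 0:
            close(i - 1)
            start = -1
    if start >= 0:
        # The sequence ended while an entity was still open.
        close(min(len(text), len(label_ids) - 1))
    return entities


if __name__ == "__main__":
    # [CLS] 北 京 的 涮 肉 [SEP] with 北京 tagged as B-LOC, I-LOC
    print(decode_entities("北京的涮肉", [0, 5, 6, 0, 0, 0, 0]))
    # prints: [{'entity': '北京', 'label': 'LOC', 'pos': [0, 1]}]
```

The inclusive `pos` convention matches the sample output in the deployment guide, where 北京 is reported at `[0, 1]`.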
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from ernie_predictor import ErniePredictor
+import argparse
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    # Command-line parameters (only --model_path is required)
+    parser.add_argument(
+        "--task_name",
+        default='seq_cls',
+        type=str,
+        help="The name of the task to run prediction for, one of: seq_cls, token_cls"
+    )
+    parser.add_argument(
+        "--model_name_or_path",
+        default="ernie-3.0-medium-zh",
+        type=str,
+        help="The directory or name of model.", )
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        required=True,
+        help="The path of the ONNX model to be used for inference.", )
+    parser.add_argument(
+        "--max_seq_length",
+        default=128,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded.", )
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    predictor = ErniePredictor(args)
+
+    if args.task_name == 'seq_cls':
+        text = ["未来自动驾驶真的会让酒驾和疲劳驾驶成历史吗?", "黄磊接受华少快问快答,不光智商逆天,情商也不逊黄渤"]
+    elif args.task_name == 'token_cls':
+        text = ["北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。", "乔丹、科比、詹姆斯和姚明都是篮球界的标志性人物。"]
+
+    outputs = predictor.predict(text)
+
+
+if __name__ == "__main__":
+    main()
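In `infer.py`, `--task_name` accepts any string, and an unsupported value is only detected later when the predictor is constructed. One possible alternative, sketched here under the assumption that only `seq_cls` and `token_cls` are valid, is to let `argparse` reject bad values at parse time with `choices`. The function name `build_parser` is hypothetical and not part of this commit; only two of the script's arguments are reproduced to keep the sketch short.

```python
import argparse


def build_parser():
    """Build a reduced parser that validates --task_name at parse time."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_name",
        default="seq_cls",
        choices=["seq_cls", "token_cls"],  # argparse rejects any other value
        help="Task to run: seq_cls or token_cls.")
    parser.add_argument(
        "--model_path",
        required=True,
        help="Path of the ONNX model used for inference.")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--model_path", "tnews_model.onnx"])
    print(args.task_name, args.model_path)  # prints: seq_cls tnews_model.onnx
```

With this approach an invalid task name produces a usage message and a non-zero exit status before any tokenizer or model is loaded.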
