
Commit c985432

Add text classification serving example (#499)
* add serving scripts
* add serving deploy scripts
* rm deploy/serving/README.md
* update docs
* update docs
* update serving usage docs
1 parent 53f181d commit c985432

File tree

3 files changed: +318 additions, 0 deletions


examples/text_classification/pretrained_models/README.md

Lines changed: 97 additions & 0 deletions
@@ -55,6 +55,9 @@ pretrained_models/
├── deploy                               # deployment
│   ├── python
│   │   └── predict.py                   # Python inference deployment example
│   └── serving
│       ├── client.py                    # client-side prediction script
│       └── export_servable_model.py     # exports the Serving model and its configuration
├── export_model.py                      # script to export dynamic-graph parameters as static-graph parameters
├── predict.py                           # prediction script
├── README.md                            # usage instructions
@@ -176,3 +179,97 @@ Data: 这个宾馆比较陈旧了,特价的房间也很一般。总体来说
Data: 怀着十分激动的心情放映，可是看着看着发现，在放映完毕后，出现一集米老鼠的动画片 Label: negative
Data: 作为老的四星酒店，房间依然很整洁，相当不错。机场接机服务很好，可以在车上办理入住手续，节省时间。 Label: positive
```

## Inference deployment with the Paddle Serving API

**NOTE:**

Service deployment with Paddle Serving requires exporting the model parameters saved from the dynamic graph as static-graph inference model parameter files. See the **导出模型** (export model) section mentioned above for how to export the model.

Inference model parameter files:

| File | Description |
|-------------------------------|----------------------------------------|
| static_graph_params.pdiparams | model weights file, loaded at inference time |
| static_graph_params.pdmodel   | model structure file, loaded at inference time |

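For reference, a minimal export sketch is shown below. The flag names used here (`--params_path`, `--output_path`) and the checkpoint path are assumptions for illustration only; follow the 导出模型 (export model) section above for the exact command.

```shell
# Hypothetical flags and checkpoint path -- see the export-model section of this README for the real usage.
python export_model.py \
    --params_path ./checkpoint/model_state.pdparams \
    --output_path ./static_graph_params
```
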
### Installing dependencies

* Server-side dependencies:

```shell
pip install paddle-serving-app paddle-serving-client paddle-serving-server==0.5.0
```

If the server can run inference on a GPU, install the GPU version of the server package instead, choosing the build that matches the server's CUDA and TensorRT versions: [Serving readme](https://github.com/PaddlePaddle/Serving/tree/v0.5.0)

```shell
pip install paddle-serving-app paddle-serving-client paddle-serving-server-gpu==0.5.0
```

* Client-side dependencies:

```shell
pip install paddle-serving-app paddle-serving-client
```

Running both the server and the client inside **docker** containers is recommended to avoid system library issues; see the [Serving readme](https://github.com/PaddlePaddle/Serving/tree/v0.5.0) for the commands that start the docker images.

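As a sketch, the docker workflow usually looks like the following; `<serving-image>` is a placeholder, so substitute the image recommended in the Serving readme for your CUDA/TensorRT setup.

```shell
# <serving-image> is a placeholder -- pick the actual image from the Serving readme linked above.
docker pull <serving-image>
docker run -it --name serving_server -p 8090:8090 <serving-image> /bin/bash
```
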
### Exporting the Serving model and configuration

To deploy with Serving, the static-graph inference model must be exported as model parameters and configuration files that Serving can load. Run:

```shell
python -u deploy/serving/export_servable_model.py \
    --inference_model_dir ./ \
    --model_file static_graph_params.pdmodel \
    --params_file static_graph_params.pdiparams
```

Configurable parameters:
* `inference_model_dir`: the directory containing the inference model, assumed here to be the current directory.
* `model_file`: the model structure file to load for inference.
* `params_file`: the model weights file to load for inference.

After the command finishes, two directories are created in the current directory: serving_server and serving_client. The serving_server directory contains the model and configuration needed on the server side and should be copied into the server-side container; the serving_client directory contains the configuration needed on the client side and should be copied into the client-side container.

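A minimal sketch of that copy step, assuming the containers are named `serving_server_container` and `serving_client_container` (hypothetical names; adjust them to your own docker setup):

```shell
# Container names below are hypothetical -- replace them with your own.
docker cp serving_server serving_server_container:/workspace/serving_server
docker cp serving_client serving_client_container:/workspace/serving_client
```
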
### Starting the server

In the server-side container, start the server:

```shell
python -m paddle_serving_server.serve \
    --model ./serving_server \
    --port 8090
```

Where:
* `model`: the directory containing the model and configuration that the server loads.
* `port`: the service port the server exposes, 8090 here.

If the server side can run inference on a GPU, the GPU id used by the server can be set when starting it:

```shell
python -m paddle_serving_server_gpu.serve \
    --model ./serving_server \
    --port 8090 \
    --gpu_id 0
```
* `gpu_id`: the server uses GPU 0.

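To confirm the server came up, an optional sanity check is to verify that the port is being listened on (this assumes net-tools is available in the container):

```shell
# Optional check that something is listening on the serving port (8090 here).
netstat -nltp | grep 8090
```
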
### Sending inference requests from the client

In the client-side container, use the serving_client directory obtained above to start the client and send RPC inference requests. The prediction results are the same as when running inference with the Paddle Inference API.

### Sending an inference request with input data from the command line

```shell
python deploy/serving/client.py \
    --client_config_file ./serving_client/serving_client_conf.prototxt \
    --server_ip_port 127.0.0.1:8090 \
    --max_seq_length 128
```

The parameters are:
- `client_config_file`: the configuration file the client loads.
- `server_ip_port`: the server's IP address and port; change them to match the actual deployment.
- `max_seq_length`: the maximum input sentence length; longer inputs are truncated.
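
Once the server is reachable, the client prints one `Data: ... Label: ...` line per input sentence; for the built-in example inputs the output matches the sample shown earlier in this README (illustrative):

```
Data: 这个宾馆比较陈旧了，特价的房间也很一般。总体来说一般 	 Label: negative
Data: 怀着十分激动的心情放映，可是看着看着发现，在放映完毕后，出现一集米老鼠的动画片 	 Label: negative
Data: 作为老的四星酒店，房间依然很整洁，相当不错。机场接机服务很好，可以在车上办理入住手续，节省时间。 	 Label: positive
```
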
examples/text_classification/pretrained_models/deploy/serving/client.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time
import numpy as np
import os

import paddle
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import ErnieTinyTokenizer
from paddle_serving_client import Client
from scipy.special import softmax

parser = argparse.ArgumentParser()
parser.add_argument(
    "--client_config_file",
    type=str,
    default="./serving_client/serving_client_conf.prototxt",
    help="Client prototxt config file.")
parser.add_argument(
    "--server_ip_port",
    type=str,
    default="127.0.0.1:8090",
    help="The ip address and port of the server.")
parser.add_argument(
    "--batch_size",
    type=int,
    default=1,
    help="Batch size per GPU/CPU for prediction.")
parser.add_argument(
    "--max_seq_length",
    type=int,
    default=128,
    help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
)
args = parser.parse_args()


def convert_example(example,
                    tokenizer,
                    label_list,
                    max_seq_length=512,
                    is_test=False):
    """
    Builds model inputs from a sequence or a pair of sequences for sequence classification tasks
    by concatenating and adding special tokens. It also creates a mask from the two sequences passed
    to be used in a sequence-pair classification task.

    A BERT sequence has the following format:

    - single sequence: ``[CLS] X [SEP]``
    - pair of sequences: ``[CLS] A [SEP] B [SEP]``

    A BERT sequence pair mask has the following format:
    ::
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |

    If there is only one sequence, only the first portion of the mask (0's) is returned.

    Args:
        example(obj:`list[str]`): List of input data, containing the text and, if present, the label.
        tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
            which contains most of the methods. Users should refer to the superclass for more information regarding methods.
        label_list(obj:`list[str]`): All the labels that the data has.
        max_seq_len(obj:`int`): The maximum total input sequence length after tokenization.
            Sequences longer than this will be truncated, sequences shorter will be padded.
        is_test(obj:`bool`, defaults to `False`): Whether the example contains a label or not.

    Returns:
        input_ids(obj:`list[int]`): The list of token ids.
        token_type_ids(obj:`list[int]`): List of sequence pair mask.
        label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test.
    """
    if is_test:
        text = example
    else:
        # When the example carries a label, it is expected as (text, label).
        text, label = example
    encoded_inputs = tokenizer(text=text, max_seq_len=max_seq_length)
    input_ids = encoded_inputs["input_ids"]
    token_type_ids = encoded_inputs["token_type_ids"]

    if not is_test:
        # create label maps
        label_map = {}
        for (i, l) in enumerate(label_list):
            label_map[l] = i

        label = label_map[label]
        label = np.array([label], dtype="int64")
        return input_ids, token_type_ids, label
    else:
        return input_ids, token_type_ids


def predict(data, label_map, batch_size):
    """
    Sends the input texts to the serving server and returns the predicted labels.

    Args:
        data (list[str]): each string is a sentence to classify.
        label_map (dict[int, str]): mapping from class index to label name.
        batch_size (int): number of examples sent to the server per request.
    Returns:
        results (list[str]): the predicted label for each input sentence.
    """

    # initialize client
    client = Client()
    client.load_client_config(args.client_config_file)
    client.connect([args.server_ip_port])

    # TODO: It may be better to do the text tokenization on the serving end rather than on the client end.
    tokenizer = ErnieTinyTokenizer.from_pretrained("ernie-tiny")
    examples = []
    for text in data:
        input_ids, token_type_ids = convert_example(
            text,
            tokenizer,
            label_list=label_map.values(),
            max_seq_length=args.max_seq_length,
            is_test=True)
        examples.append((input_ids, token_type_ids))

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # input ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # token type ids
    ): fn(samples)

    # Separates data into batches.
    batches = [
        examples[idx:idx + batch_size]
        for idx in range(0, len(examples), batch_size)
    ]

    results = []
    for batch in batches:
        input_ids, token_type_ids = batchify_fn(batch)
        fetch_map = client.predict(
            feed={"input_ids": input_ids,
                  "token_type_ids": token_type_ids},
            fetch=["save_infer_model/scale_0.tmp_1"],
            batch=True)
        output_data = np.array(fetch_map["save_infer_model/scale_0.tmp_1"])
        probs = softmax(output_data, axis=1)
        idx = np.argmax(probs, axis=1)
        idx = idx.tolist()
        labels = [label_map[i] for i in idx]
        results.extend(labels)

    return results


if __name__ == '__main__':
    paddle.enable_static()
    data = [
        '这个宾馆比较陈旧了，特价的房间也很一般。总体来说一般',
        '怀着十分激动的心情放映，可是看着看着发现，在放映完毕后，出现一集米老鼠的动画片',
        '作为老的四星酒店，房间依然很整洁，相当不错。机场接机服务很好，可以在车上办理入住手续，节省时间。',
    ]
    label_map = {0: 'negative', 1: 'positive'}
    results = predict(data, label_map, args.batch_size)
    for idx, text in enumerate(data):
        print('Data: {} \t Label: {}'.format(text, results[idx]))
examples/text_classification/pretrained_models/deploy/serving/export_servable_model.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import paddle
import paddle_serving_client.io as serving_io


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--inference_model_dir",
        type=str,
        default="./",
        help="The directory of the inference model.")
    parser.add_argument(
        "--model_file",
        type=str,
        default='./static_graph_params.pdmodel',
        help="The inference model file name.")
    parser.add_argument(
        "--params_file",
        type=str,
        default='./static_graph_params.pdiparams',
        help="The input inference parameters file name.")
    return parser.parse_args()


if __name__ == '__main__':
    paddle.enable_static()
    args = parse_args()
    feed_names, fetch_names = serving_io.inference_model_to_serving(
        dirname=args.inference_model_dir,
        serving_server="serving_server",
        serving_client="serving_client",
        model_filename=args.model_file,
        params_filename=args.params_file)
    print("model feed_names : %s" % feed_names)
    print("model fetch_names : %s" % fetch_names)
