
Commit 8cdc31c

support simple serving for ernie-m text classification (#4435)

* support ernie m
* add simple serving for ernie m
* add handler in library
* change class name

1 parent 03679e2

8 files changed: +167 −14


applications/text_classification/hierarchical/deploy/simple_serving/README.md

Lines changed: 4 additions & 1 deletion
@@ -16,7 +16,10 @@ pip install paddlenlp --upgrade
 ```bash
 paddlenlp server server:app --host 0.0.0.0 --port 8189
 ```
-
+If serving an ERNIE-M model, start the server with:
+```bash
+paddlenlp server ernie_m_server:app --host 0.0.0.0 --port 8189
+```
 #### Send requests for the classification task
 ```bash
 python client.py
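For context, the request that client.py sends can be reproduced in a few lines of Python. This is a minimal sketch, not the shipped client: the endpoint path is assumed from the `app.register("models/cls_hierarchical", ...)` call in ernie_m_server.py below, the host/port from the command above, and the payload shape from the handler code in this commit (`data["text"]`, `parameters["max_seq_len"]`, `parameters["batch_size"]`).

```python
# Minimal client sketch (endpoint path and payload shape assumed as noted above).
import requests

url = "http://0.0.0.0:8189/models/cls_hierarchical"
payload = {
    "data": {"text": ["一段待分类的测试文本"]},
    "parameters": {"max_seq_len": 128, "batch_size": 1},
}
response = requests.post(url, json=payload)
print(response.json())
```

The multi_class and multi_label applications below follow the same pattern with their own registered names ("models/cls_multi_class" and "models/cls_multi_label").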
applications/text_classification/hierarchical/deploy/simple_serving/ernie_m_server.py

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+# coding:utf-8
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from paddlenlp import SimpleServer
+from paddlenlp.server import ERNIEMHandler, MultiLabelClassificationPostHandler
+
+app = SimpleServer()
+app.register(
+    "models/cls_hierarchical",
+    model_path="../../export",
+    tokenizer_name="ernie-m-base",
+    model_handler=ERNIEMHandler,
+    post_handler=MultiLabelClassificationPostHandler,
+)

applications/text_classification/multi_class/deploy/simple_serving/README.md

Lines changed: 4 additions & 1 deletion
@@ -16,7 +16,10 @@ pip install paddlenlp >= 2.4.4
 ```bash
 paddlenlp server server:app --host 0.0.0.0 --port 8189
 ```
-
+If serving an ERNIE-M model, start the server with:
+```bash
+paddlenlp server ernie_m_server:app --host 0.0.0.0 --port 8189
+```
 #### Start the classification client
 ```bash
 python client.py
applications/text_classification/multi_class/deploy/simple_serving/ernie_m_server.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+# coding:utf-8
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlenlp import SimpleServer
+from paddlenlp.server import ERNIEMHandler, MultiClassificationPostHandler
+
+app = SimpleServer()
+app.register(
+    "models/cls_multi_class",
+    model_path="../../export",
+    tokenizer_name="ernie-m-base",
+    model_handler=ERNIEMHandler,
+    post_handler=MultiClassificationPostHandler,
+)
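This script differs from the hierarchical one only in the registered name and the post handler. Conceptually, the two post handlers turn the logits produced by ERNIEMHandler into predictions with different decision rules; the following numpy sketch illustrates the usual convention (illustrative only, not the actual paddlenlp.server implementation, and the 0.5 threshold is an assumption):

```python
# Illustrative decision rules for the two post handler families.
import numpy as np

logits = np.array([2.1, -0.3, 0.8])  # one example, three classes/labels

# Multi-class: classes compete; softmax then argmax picks a single label.
probs = np.exp(logits) / np.exp(logits).sum()
multi_class_pred = int(np.argmax(probs))  # -> 0

# Multi-label: labels are scored independently with sigmoids; keep every
# label whose probability clears the threshold (0.5 assumed here).
sigmoid_probs = 1 / (1 + np.exp(-logits))
multi_label_pred = [i for i, p in enumerate(sigmoid_probs) if p > 0.5]  # -> [0, 2]
```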

applications/text_classification/multi_label/deploy/simple_serving/README.md

Lines changed: 4 additions & 1 deletion
@@ -16,7 +16,10 @@ pip install paddlenlp --upgrade
 ```bash
 paddlenlp server server:app --host 0.0.0.0 --port 8189
 ```
-
+If serving an ERNIE-M model, start the server with:
+```bash
+paddlenlp server ernie_m_server:app --host 0.0.0.0 --port 8189
+```
 #### Send requests for the classification task
 ```bash
 python client.py
applications/text_classification/multi_label/deploy/simple_serving/ernie_m_server.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+# coding:utf-8
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlenlp import SimpleServer
+from paddlenlp.server import ERNIEMHandler, MultiLabelClassificationPostHandler
+
+app = SimpleServer()
+app.register(
+    "models/cls_multi_label",
+    model_path="../../export",
+    tokenizer_name="ernie-m-base",
+    model_handler=ERNIEMHandler,
+    post_handler=MultiLabelClassificationPostHandler,
+)

paddlenlp/server/handlers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
     MultiClassificationPostHandler,
     MultiLabelClassificationPostHandler,
 )
-from .custom_model_handler import CustomModelHandler
+from .custom_model_handler import CustomModelHandler, ERNIEMHandler
 from .qa_model_handler import QAModelHandler
 from .taskflow_handler import TaskflowHandler
 from .token_model_handler import TokenClsModelHandler

paddlenlp/server/handlers/custom_model_handler.py

Lines changed: 76 additions & 9 deletions
@@ -1,5 +1,5 @@
 # coding:utf-8
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,10 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import numpy as np
+
+from ...data import Pad, Tuple
 from .base_handler import BaseModelHandler
-from ...transformers import AutoTokenizer
-from ...data import Tuple, Pad
-from ...utils.log import logger


 class CustomModelHandler(BaseModelHandler):
@@ -59,10 +58,11 @@ def process(cls, predictor, tokenizer, data, parameters):
         # Separates data into batches.
         batches = [examples[i : i + batch_size] for i in range(0, len(examples), batch_size)]

-        batchify_fn = lambda samples, fn=Tuple(
-            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input
-            Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # segment
-        ): fn(samples)
+        def batchify_fn(samples):
+            return Tuple(
+                Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),
+                Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),
+            )(samples)

         results = [[]] * predictor._output_num
         for batch in batches:
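The hunk above replaces a default-argument lambda with an equivalent nested function; behavior is unchanged. For readers unfamiliar with paddlenlp.data, here is a self-contained sketch of what this collation does (toy token ids, not part of the commit):

```python
# Toy illustration of the Tuple(Pad, Pad) collation used in batchify_fn above.
from paddlenlp.data import Pad, Tuple

samples = [
    ([1, 2, 3], [0, 0, 0]),              # (input_ids, token_type_ids), length 3
    ([4, 5, 6, 7, 8], [0, 0, 1, 1, 1]),  # length 5
]
batchify_fn = Tuple(
    Pad(axis=0, pad_val=0, dtype="int64"),  # pad input_ids to the longest sample
    Pad(axis=0, pad_val=0, dtype="int64"),  # pad token_type_ids the same way
)
input_ids, token_type_ids = batchify_fn(samples)
print(input_ids.shape)  # (2, 5): the shorter sample is right-padded with pad_val
```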
@@ -75,7 +75,74 @@ def process(cls, predictor, tokenizer, data, parameters):
                 for i, out in enumerate(output):
                     results[i].append(out)
             else:
-                outputs = predictor._predictor.run(None, {"input_ids": input_ids, "token_type_ids": token_type_ids})
+                output = predictor._predictor.run(None, {"input_ids": input_ids, "token_type_ids": token_type_ids})
+                for i, out in enumerate(output):
+                    results[i].append(out)
+
+        # Resolve the logits result and get the predicted label and confidence
+        results_concat = []
+        for i in range(0, len(results)):
+            results_concat.append(np.concatenate(results[i], axis=0))
+        out_dict = {"logits": results_concat[0].tolist(), "data": data}
+        for i in range(1, len(results_concat)):
+            out_dict[f"logits_{i}"] = results_concat[i].tolist()
+        return out_dict
+
+
+class ERNIEMHandler(BaseModelHandler):
+    def __init__(self):
+        super().__init__()
+
+    @classmethod
+    def process(cls, predictor, tokenizer, data, parameters):
+        max_seq_len = 128
+        batch_size = 1
+        if "max_seq_len" in parameters:
+            max_seq_len = parameters["max_seq_len"]
+        if "batch_size" in parameters:
+            batch_size = parameters["batch_size"]
+        text = None
+        if "text" in data:
+            text = data["text"]
+        if text is None:
+            return {}
+        if isinstance(text, str):
+            text = [text]
+        has_pair = False
+        if "text_pair" in data and data["text_pair"] is not None:
+            text_pair = data["text_pair"]
+            if isinstance(text_pair, str):
+                text_pair = [text_pair]
+            if len(text) != len(text_pair):
+                raise ValueError("The length of text and text_pair must be the same.")
+            has_pair = True
+
+        # Tokenize the inputs (ERNIE-M uses input_ids only)
+        examples = []
+        for idx in range(len(text)):
+            if has_pair:
+                result = tokenizer(text=text[idx], text_pair=text_pair[idx], max_length=max_seq_len)
+            else:
+                result = tokenizer(text=text[idx], max_length=max_seq_len)
+            examples.append(result["input_ids"])
+
+        # Separates data into batches.
+        batches = [examples[i : i + batch_size] for i in range(0, len(examples), batch_size)]
+
+        def batchify_fn(samples):
+            return Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64")(samples)
+
+        results = [[]] * predictor._output_num
+        for batch in batches:
+            input_ids = batchify_fn(batch)
+            if predictor._predictor_type == "paddle_inference":
+                predictor._input_handles[0].copy_from_cpu(input_ids)
+                predictor._predictor.run()
+                output = [output_handle.copy_to_cpu() for output_handle in predictor._output_handles]
+                for i, out in enumerate(output):
+                    results[i].append(out)
+            else:
+                output = predictor._predictor.run(None, {"input_ids": input_ids})
                 for i, out in enumerate(output):
                     results[i].append(out)

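A note on why ERNIEMHandler exists at all: ERNIE-M does not use segment (token type) embeddings, so the exported model is fed only input_ids, whereas CustomModelHandler also builds and feeds token_type_ids. The difference is visible from the tokenizer alone; a quick sketch (assumes a local PaddleNLP installation that can download the ernie-m-base assets):

```python
# Sketch: the ERNIE-M tokenizer output carries no token_type_ids,
# which is why the handler above pads and feeds input_ids only.
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ernie-m-base")
encoded = tokenizer("这是一条测试文本", max_length=128)
print(encoded.keys())  # expected: input_ids only, no token_type_ids
```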