Skip to content

Commit f982df8

Browse files
authored
fix paddle2onnx postprocess bug (#2386)
1 parent c367f8f commit f982df8

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

model_zoo/ernie-3.0/deploy/paddle2onnx/ernie_predictor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ def token_cls_postprocess(self, infer_data, input_data):
167167
label_name = ""
168168
items = []
169169
for i, label in enumerate(token_label):
170-
if self.label_names[label] == "O" and start >= 0:
170+
if (self.label_names[label] == "O" or
171+
"B-" in self.label_names[label]) and start >= 0:
171172
entity = input_data[batch][start:i - 1]
172173
if isinstance(entity, list):
173174
entity = "".join(entity)
@@ -177,7 +178,7 @@ def token_cls_postprocess(self, infer_data, input_data):
177178
"label": label_name,
178179
})
179180
start = -1
180-
elif "B-" in self.label_names[label]:
181+
if "B-" in self.label_names[label]:
181182
start = i - 1
182183
label_name = self.label_names[label][2:]
183184
if start >= 0:

model_zoo/ernie-3.0/deploy/python/ernie_predictor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def __init__(self,
9494
enable_onnx_checker=True)
9595
dynamic_quantize_model = onnx_model
9696
providers = ['CUDAExecutionProvider']
97+
if device == 'cpu':
98+
providers = ['CPUExecutionProvider']
9799
if use_quantize:
98100
float_onnx_file = "model.onnx"
99101
with open(float_onnx_file, "wb") as f:
@@ -103,7 +105,6 @@ def __init__(self,
103105
providers = ['CPUExecutionProvider']
104106
sess_options = ort.SessionOptions()
105107
sess_options.intra_op_num_threads = num_threads
106-
sess_options.inter_op_num_threads = num_threads
107108
self.predictor = ort.InferenceSession(
108109
dynamic_quantize_model,
109110
sess_options=sess_options,

0 commit comments

Comments
 (0)