
Commit d5fd5ae

[Feat] OCR and PP-StructureV3 high-performance servers support dynamic batching (#4706)
* Fix
* Enhance installation
1 parent 54332b3 commit d5fd5ae

File tree

13 files changed: +465 / -74 lines changed

deploy/hps/sdk/pipelines/OCR/server/model_repo/ocr/1/model.py

Lines changed: 140 additions & 16 deletions
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
+from concurrent.futures import ThreadPoolExecutor
+from operator import itemgetter
 from typing import Any, Dict, Final, List, Tuple
 
 from paddlex_hps_server import (
@@ -27,6 +30,17 @@
 _DEFAULT_MAX_OUTPUT_IMG_SIZE: Final[Tuple[int, int]] = (2000, 2000)
 
 
+class _SequentialExecutor(object):
+    def map(self, fn, *iterables):
+        return map(fn, *iterables)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
+
+
 class TritonPythonModel(BaseTritonPythonModel):
     def initialize(self, args):
         super().initialize(args)
@@ -68,6 +82,129 @@ def get_result_model_type(self):
         return schemas.ocr.InferResult
 
     def run(self, input, log_id):
+        return self.run_batch([input], [log_id], log_id)
+
+    def run_batch(self, inputs, log_ids, batch_id):
+        result_or_output_dic = {}
+
+        input_groups = self._group_inputs(inputs)
+
+        max_group_size = max(map(len, input_groups))
+        if max_group_size > 1:
+            executor = ThreadPoolExecutor(max_workers=max_group_size)
+        else:
+            executor = _SequentialExecutor()
+
+        with executor:
+            for input_group in input_groups:
+                input_ids_g = list(map(itemgetter(0), input_group))
+                inputs_g = list(map(itemgetter(1), input_group))
+
+                log_ids_g = [log_ids[i] for i in input_ids_g]
+
+                ret = executor.map(self._preprocess, inputs_g, log_ids_g)
+                ind_img_lsts, ind_data_info_lst, ind_visualize_enabled_lst = [], [], []
+                for i, item in enumerate(ret):
+                    if isinstance(item, tuple):
+                        assert len(item) == 3, len(item)
+                        ind_img_lsts.append(item[0])
+                        ind_data_info_lst.append(item[1])
+                        ind_visualize_enabled_lst.append(item[2])
+                    else:
+                        input_id = input_ids_g[i]
+                        result_or_output_dic[input_id] = item
+
+                if len(ind_img_lsts):
+                    images = [img for item in ind_img_lsts for img in item]
+                    preds = list(
+                        self.pipeline(
+                            images,
+                            use_doc_orientation_classify=inputs_g[
+                                0
+                            ].useDocOrientationClassify,
+                            use_doc_unwarping=inputs_g[0].useDocUnwarping,
+                            use_textline_orientation=inputs_g[0].useTextlineOrientation,
+                            text_det_limit_side_len=inputs_g[0].textDetLimitSideLen,
+                            text_det_limit_type=inputs_g[0].textDetLimitType,
+                            text_det_thresh=inputs_g[0].textDetThresh,
+                            text_det_box_thresh=inputs_g[0].textDetBoxThresh,
+                            text_det_unclip_ratio=inputs_g[0].textDetUnclipRatio,
+                            text_rec_score_thresh=inputs_g[0].textRecScoreThresh,
+                            return_word_box=inputs_g[0].returnWordBox,
+                        )
+                    )
+
+                    if len(preds) != len(images):
+                        raise RuntimeError(
+                            f"The number of predictions ({len(preds)}) is not the same as the number of input images ({len(images)})."
+                        )
+
+                    start_idx = 0
+                    ind_preds = []
+                    for item in ind_img_lsts:
+                        ind_preds.append(preds[start_idx : start_idx + len(item)])
+                        start_idx += len(item)
+
+                    for i, result in zip(
+                        input_ids_g,
+                        executor.map(
+                            self._postprocess,
+                            ind_img_lsts,
+                            ind_data_info_lst,
+                            ind_visualize_enabled_lst,
+                            ind_preds,
+                            log_ids_g,
+                            inputs_g,
+                        ),
+                    ):
+                        result_or_output_dic[i] = result
+
+        assert len(result_or_output_dic) == len(
+            inputs
+        ), f"Expected {len(inputs)} results or outputs, but got {len(result_or_output_dic)}"
+
+        return [result_or_output_dic[i] for i in range(len(inputs))]
+
+    def _group_inputs(self, inputs):
+        def _to_hashable(obj):
+            if isinstance(obj, list):
+                return tuple(obj)
+            elif isinstance(obj, dict):
+                return tuple(sorted(obj.items()))
+            else:
+                return obj
+
+        def _hash(input):
+            return hash(
+                tuple(
+                    map(
+                        _to_hashable,
+                        (
+                            input.useDocOrientationClassify,
+                            input.useDocUnwarping,
+                            input.useTextlineOrientation,
+                            input.textDetLimitSideLen,
+                            input.textDetLimitType,
+                            input.textDetThresh,
+                            input.textDetBoxThresh,
+                            input.textDetUnclipRatio,
+                            input.textRecScoreThresh,
+                            input.returnWordBox,
+                        ),
+                    )
+                )
+            )
+
+        groups = {}
+        for i, inp in enumerate(inputs):
+            group_key = _hash(inp)
+            if group_key not in groups:
+                groups[group_key] = []
+            groups[group_key].append((i, inp))
+
+        return list(groups.values())
+
+    def _preprocess(self, input, log_id):
         if input.fileType is None:
             if utils.is_url(input.file):
                 maybe_file_type = utils.infer_file_type(input.file)
@@ -101,24 +238,11 @@ def run(self, input, log_id):
             max_num_imgs=self.context["max_num_input_imgs"],
         )
 
-        result = list(
-            self.pipeline(
-                images,
-                use_doc_orientation_classify=input.useDocOrientationClassify,
-                use_doc_unwarping=input.useDocUnwarping,
-                use_textline_orientation=input.useTextlineOrientation,
-                text_det_limit_side_len=input.textDetLimitSideLen,
-                text_det_limit_type=input.textDetLimitType,
-                text_det_thresh=input.textDetThresh,
-                text_det_box_thresh=input.textDetBoxThresh,
-                text_det_unclip_ratio=input.textDetUnclipRatio,
-                text_rec_score_thresh=input.textRecScoreThresh,
-                return_word_box=input.returnWordBox,
-            )
-        )
+        return images, data_info, visualize_enabled
 
+    def _postprocess(self, images, data_info, visualize_enabled, preds, log_id, input):
         ocr_results: List[Dict[str, Any]] = []
-        for i, (img, item) in enumerate(zip(images, result)):
+        for i, (img, item) in enumerate(zip(images, preds)):
             pruned_res = app_common.prune_result(item.json["res"])
             if visualize_enabled:
                 output_imgs = item.img
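
The key change above is that run() now delegates to run_batch(): requests that Triton delivers in one batch are grouped by _group_inputs() on their inference parameters, each group goes through the pipeline in a single call, and the per-request results are scattered back in the original order. Grouping matters because the pipeline call takes one set of options per invocation, so only requests with identical options can share a batch. The sketch below is a minimal, standalone illustration of that group-then-scatter pattern, not the server code: requests are plain dicts and predict is a stand-in for the OCR pipeline.

# Minimal standalone sketch of the grouping idea behind _group_inputs/run_batch.
# "Requests" here are plain dicts and "predict" is a stand-in for the pipeline;
# only the group-then-scatter pattern mirrors the server code.

def group_by_params(requests):
    """Group request indices by a hashable view of their inference parameters."""
    groups = {}
    for idx, req in enumerate(requests):
        key = tuple(sorted(req["params"].items()))
        groups.setdefault(key, []).append((idx, req))
    return list(groups.values())


def run_batch(requests, predict):
    """Run each parameter group through `predict` once, then restore request order."""
    results = {}
    for group in group_by_params(requests):
        indices = [idx for idx, _ in group]
        images = [req["image"] for _, req in group]
        params = dict(group[0][1]["params"])  # identical for the whole group
        preds = predict(images, **params)
        for idx, pred in zip(indices, preds):
            results[idx] = pred
    return [results[i] for i in range(len(requests))]


if __name__ == "__main__":
    fake_predict = lambda images, **params: [f"{img}:{params}" for img in images]
    reqs = [
        {"image": "a.png", "params": {"use_doc_unwarping": False}},
        {"image": "b.png", "params": {"use_doc_unwarping": True}},
        {"image": "c.png", "params": {"use_doc_unwarping": False}},
    ]
    print(run_batch(reqs, fake_predict))  # results come back in request order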
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+backend: "python"
+max_batch_size: 8
+input [
+  {
+    name: "input"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  }
+]
+output [
+  {
+    name: "output"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  }
+]
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]
+dynamic_batching { }
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+backend: "python"
+max_batch_size: 8
+input [
+  {
+    name: "input"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  }
+]
+output [
+  {
+    name: "output"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  }
+]
+instance_group [
+  {
+    count: 1
+    kind: KIND_GPU
+    gpus: [ 0 ]
+  }
+]
+dynamic_batching { }
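
Both config fragments enable Triton's dynamic batcher (dynamic_batching { } with max_batch_size: 8), so requests that arrive close together can be coalesced into a single execution, which run_batch() then processes as one batch. Below is a rough client-side sketch of exercising this; it is only an illustration under assumptions not taken from this commit: the gRPC endpoint localhost:8001, the model name "ocr", and a JSON payload with a "file" field. The real request schema is defined by paddlex_hps_server.

# Rough client-side sketch: fire several requests concurrently so the server's
# dynamic batcher has a chance to coalesce them.
# Assumptions (not from this diff): endpoint localhost:8001, model name "ocr",
# and a JSON payload with a "file" field.
import json
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import tritonclient.grpc as triton_grpc

client = triton_grpc.InferenceServerClient("localhost:8001")


def infer_one(payload):
    # Shape [1, 1]: the leading batch dimension plus the dims: [ -1 ] string tensor.
    data = np.array([[json.dumps(payload)]], dtype=np.object_)
    inp = triton_grpc.InferInput("input", [1, 1], "BYTES")
    inp.set_data_from_numpy(data)
    out = triton_grpc.InferRequestedOutput("output")
    result = client.infer(model_name="ocr", inputs=[inp], outputs=[out])
    return result.as_numpy("output")


with ThreadPoolExecutor(max_workers=8) as pool:
    payloads = [{"file": f"https://example.com/page_{i}.png"} for i in range(8)]
    print(list(pool.map(infer_one, payloads)))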
