Commit c886936
[Enhancement] decouple batch_size to det_batch_size, rec_batch_size and kie_batch_size in MMOCRInferencer (#1801)
* decouple batch_size to det_batch_size, rec_batch_size, kie_batch_size and chunk_size in MMOCRInferencer
* remove chunk_size parameter
* add Optional keyword in function definitions and doc strings
* add det_batch_size, rec_batch_size, kie_batch_size in user_guides
* minor formatting
1 parent 22f40b7 commit c886936

File tree: 3 files changed (+55, -6 lines)

docs/en/user_guides/inference.md (3 additions, 0 deletions)

@@ -460,6 +460,9 @@ Here are extensive lists of parameters that you can use.
 | `inputs` | str/list/tuple/np.array | **required** | It can be a path to an image/a folder, an np array or a list/tuple (with img paths or np arrays) |
 | `return_datasamples` | bool | False | Whether to return results as DataSamples. If False, the results will be packed into a dict. |
 | `batch_size` | int | 1 | Inference batch size. |
+| `det_batch_size` | int, optional | None | Inference batch size for the text detection model. Overwrites `batch_size` if it is not None. |
+| `rec_batch_size` | int, optional | None | Inference batch size for the text recognition model. Overwrites `batch_size` if it is not None. |
+| `kie_batch_size` | int, optional | None | Inference batch size for the KIE model. Overwrites `batch_size` if it is not None. |
 | `return_vis` | bool | False | Whether to return the visualization result. |
 | `print_result` | bool | False | Whether to print the inference result to the console. |
 | `show` | bool | False | Whether to display the visualization results in a popup window. |
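As a quick illustration of the new knobs, the per-task sizes can be passed straight to the inferencer call. A minimal sketch, assuming the `DBNet`/`CRNN` model aliases and the image path as placeholders:

```python
from mmocr.apis import MMOCRInferencer

# 'DBNet' and 'CRNN' are example model aliases; substitute the
# detection/recognition models you actually use.
infer = MMOCRInferencer(det='DBNet', rec='CRNN')

# Detection sees whole images while recognition sees many small crops,
# so the recognition stage can usually afford a much larger batch.
result = infer(
    'demo/demo_text_ocr.jpg',  # placeholder path
    det_batch_size=2,          # batch size for the text detection stage
    rec_batch_size=32,         # batch size for the text recognition stage
)
```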

docs/zh_cn/user_guides/inference.md (3 additions, 0 deletions)

@@ -457,6 +457,9 @@ outputs
 | `inputs` | str/list/tuple/np.array | **required** | It can be a path to an image/folder, a numpy array, or a list/tuple containing image paths or numpy arrays |
 | `return_datasamples` | bool | False | Whether to return results as DataSamples. If False, the results will be packed into a dict. |
 | `batch_size` | int | 1 | Inference batch size. |
+| `det_batch_size` | int, optional | None | Inference batch size for the text detection model. Overwrites `batch_size` if it is not None. |
+| `rec_batch_size` | int, optional | None | Inference batch size for the text recognition model. Overwrites `batch_size` if it is not None. |
+| `kie_batch_size` | int, optional | None | Inference batch size for the key information extraction (KIE) model. Overwrites `batch_size` if it is not None. |
 | `return_vis` | bool | False | Whether to return the visualization result. |
 | `print_result` | bool | False | Whether to print the inference result to the console. |
 | `show` | bool | False | Whether to display the visualization results in a popup window. |

mmocr/apis/inferencers/mmocr_inferencer.py (49 additions, 6 deletions)

@@ -105,34 +105,54 @@ def _inputs2ndarrray(self, inputs: List[InputsType]) -> List[np.ndarray]:
                          'supported yet.')
         return new_inputs
 
-    def forward(self, inputs: InputsType, batch_size: int,
+    def forward(self,
+                inputs: InputsType,
+                batch_size: int = 1,
+                det_batch_size: Optional[int] = None,
+                rec_batch_size: Optional[int] = None,
+                kie_batch_size: Optional[int] = None,
                 **forward_kwargs) -> PredType:
         """Forward the inputs to the model.
 
         Args:
             inputs (InputsType): The inputs to be forwarded.
             batch_size (int): Batch size. Defaults to 1.
+            det_batch_size (Optional[int]): Batch size for the text detection
+                model. Overwrites batch_size if it is not None.
+                Defaults to None.
+            rec_batch_size (Optional[int]): Batch size for the text
+                recognition model. Overwrites batch_size if it is not None.
+                Defaults to None.
+            kie_batch_size (Optional[int]): Batch size for the KIE model.
+                Overwrites batch_size if it is not None.
+                Defaults to None.
 
         Returns:
             Dict: The prediction results, possibly with keys "det", "rec" and
             "kie".
         """
         result = {}
         forward_kwargs['progress_bar'] = False
+        if det_batch_size is None:
+            det_batch_size = batch_size
+        if rec_batch_size is None:
+            rec_batch_size = batch_size
+        if kie_batch_size is None:
+            kie_batch_size = batch_size
         if self.mode == 'rec':
             # The extra list wrapper here is for the ease of postprocessing
             self.rec_inputs = inputs
             predictions = self.textrec_inferencer(
                 self.rec_inputs,
                 return_datasamples=True,
-                batch_size=batch_size,
+                batch_size=rec_batch_size,
                 **forward_kwargs)['predictions']
             result['rec'] = [[p] for p in predictions]
         elif self.mode.startswith('det'):  # 'det'/'det_rec'/'det_rec_kie'
             result['det'] = self.textdet_inferencer(
                 inputs,
                 return_datasamples=True,
-                batch_size=batch_size,
+                batch_size=det_batch_size,
                 **forward_kwargs)['predictions']
             if self.mode.startswith('det_rec'):  # 'det_rec'/'det_rec_kie'
                 result['rec'] = []

@@ -149,7 +169,7 @@ def forward(self, inputs: InputsType, batch_size: int,
                     self.textrec_inferencer(
                         self.rec_inputs,
                         return_datasamples=True,
-                        batch_size=batch_size,
+                        batch_size=rec_batch_size,
                         **forward_kwargs)['predictions'])
             if self.mode == 'det_rec_kie':
                 self.kie_inputs = []

@@ -172,7 +192,7 @@ def forward(self, inputs: InputsType, batch_size: int,
                 result['kie'] = self.kie_inferencer(
                     self.kie_inputs,
                     return_datasamples=True,
-                    batch_size=batch_size,
+                    batch_size=kie_batch_size,
                     **forward_kwargs)['predictions']
         return result
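Both `forward` and `__call__` (below) resolve the new arguments identically: an explicit per-task value wins, otherwise the generic `batch_size` applies. A minimal sketch of that fallback rule; `resolve_batch_size` is a hypothetical helper written here only to illustrate the semantics, not part of MMOCR:

```python
from typing import Optional

def resolve_batch_size(batch_size: int,
                       task_batch_size: Optional[int]) -> int:
    """Hypothetical helper mirroring the fallback logic above."""
    # A per-task batch size, when given, takes precedence; None falls
    # back to the generic batch_size.
    return batch_size if task_batch_size is None else task_batch_size

assert resolve_batch_size(4, None) == 4  # no override: generic size used
assert resolve_batch_size(4, 16) == 16   # explicit override wins
```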

@@ -219,6 +239,9 @@ def __call__(
         self,
         inputs: InputsType,
         batch_size: int = 1,
+        det_batch_size: Optional[int] = None,
+        rec_batch_size: Optional[int] = None,
+        kie_batch_size: Optional[int] = None,
         out_dir: str = 'results/',
         return_vis: bool = False,
         save_vis: bool = False,

@@ -231,6 +254,15 @@ def __call__(
             inputs (InputsType): Inputs for the inferencer. It can be a path
                 to image / image directory, or an array, or a list of these.
             batch_size (int): Batch size. Defaults to 1.
+            det_batch_size (Optional[int]): Batch size for the text detection
+                model. Overwrites batch_size if it is not None.
+                Defaults to None.
+            rec_batch_size (Optional[int]): Batch size for the text
+                recognition model. Overwrites batch_size if it is not None.
+                Defaults to None.
+            kie_batch_size (Optional[int]): Batch size for the KIE model.
+                Overwrites batch_size if it is not None.
+                Defaults to None.
             out_dir (str): Output directory of results. Defaults to 'results/'.
             return_vis (bool): Whether to return the visualization result.
                 Defaults to False.

@@ -269,12 +301,23 @@ def __call__(
             **kwargs)
 
         ori_inputs = self._inputs_to_list(inputs)
+        if det_batch_size is None:
+            det_batch_size = batch_size
+        if rec_batch_size is None:
+            rec_batch_size = batch_size
+        if kie_batch_size is None:
+            kie_batch_size = batch_size
 
         chunked_inputs = super(BaseMMOCRInferencer,
                                self)._get_chunk_data(ori_inputs, batch_size)
         results = {'predictions': [], 'visualization': []}
         for ori_input in track(chunked_inputs, description='Inference'):
-            preds = self.forward(ori_input, batch_size, **forward_kwargs)
+            preds = self.forward(
+                ori_input,
+                det_batch_size=det_batch_size,
+                rec_batch_size=rec_batch_size,
+                kie_batch_size=kie_batch_size,
+                **forward_kwargs)
             visualization = self.visualize(
                 ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs)
             batch_res = self.postprocess(
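Note that `__call__` still chunks the raw inputs with the generic `batch_size` (see the `_get_chunk_data` call above), so the per-task values only control batching inside each sub-inferencer. A usage sketch for the full det_rec_kie pipeline; the model aliases and the image path are placeholders:

```python
from mmocr.apis import MMOCRInferencer

# Full detection -> recognition -> KIE pipeline; 'DBNet', 'SAR' and
# 'SDMGR' are example model aliases, and the image path is a placeholder.
kie_infer = MMOCRInferencer(det='DBNet', rec='SAR', kie='SDMGR')

result = kie_infer(
    'demo/demo_kie.jpeg',  # placeholder path
    batch_size=4,      # still used to chunk the inputs in __call__
    kie_batch_size=8,  # overrides batch_size for the KIE stage only
)
```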
