Skip to content

Commit f6bc4f7

Browse files
authored
Merge pull request #2003 from icecraft/feat/batch_analyze_with_ocr_and_lang
feat: batch inference with ocr and lang flag
2 parents 2c8470b + bbba2a1 commit f6bc4f7

File tree

3 files changed

+47
-60
lines changed

3 files changed

+47
-60
lines changed

magic_pdf/model/batch_analyze.py

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,13 +17,25 @@
1717

1818

1919
class BatchAnalyze:
20-
def __init__(self, model: CustomPEKModel, batch_ratio: int):
21-
self.model = model
20+
def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable):
21+
self.model_manager = model_manager
2222
self.batch_ratio = batch_ratio
23-
24-
def __call__(self, images: list) -> list:
23+
self.show_log = show_log
24+
self.layout_model = layout_model
25+
self.formula_enable = formula_enable
26+
self.table_enable = table_enable
27+
28+
def __call__(self, images_with_extra_info: list) -> list:
29+
if len(images_with_extra_info) == 0:
30+
return []
31+
2532
images_layout_res = []
2633
layout_start_time = time.time()
34+
_, fst_ocr, fst_lang = images_with_extra_info[0]
35+
self.model = self.model_manager.get_model(fst_ocr, self.show_log, fst_lang, self.layout_model, self.formula_enable, self.table_enable)
36+
37+
images = [image for image, _, _ in images_with_extra_info]
38+
2739
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
2840
# layoutlmv3
2941
for image in images:
@@ -79,6 +91,8 @@ def __call__(self, images: list) -> list:
7991
table_count = 0
8092
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
8193
for index in range(len(images)):
94+
_, ocr_enable, _lang = images_with_extra_info[index]
95+
self.model = self.model_manager.get_model(ocr_enable, self.show_log, _lang, self.layout_model, self.formula_enable, self.table_enable)
8296
layout_res = images_layout_res[index]
8397
np_array_img = images[index]
8498

@@ -99,7 +113,7 @@ def __call__(self, images: list) -> list:
99113
# OCR recognition
100114
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
101115

102-
if self.model.apply_ocr:
116+
if ocr_enable:
103117
ocr_res = self.model.ocr_model.ocr(
104118
new_image, mfd_res=adjusted_mfdetrec_res
105119
)[0]
@@ -159,9 +173,7 @@ def __call__(self, images: list) -> list:
159173
table_count += len(table_res_list)
160174

161175
if self.model.apply_ocr:
162-
logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}')
163-
else:
164-
logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
176+
logger.info(f'det or det time costs: {round(ocr_time, 2)}, image num: {ocr_count}')
165177
if self.model.apply_table:
166178
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
167179

magic_pdf/model/doc_analyze_by_custom_model.py

Lines changed: 24 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from loguru import logger
1616

1717
from magic_pdf.model.sub_modules.model_utils import get_vram
18-
18+
from magic_pdf.config.enums import SupportedPdfParseMethod
1919
import magic_pdf.model as model_config
2020
from magic_pdf.data.dataset import Dataset
2121
from magic_pdf.libs.clean_memory import clean_memory
@@ -150,12 +150,13 @@ def doc_analyze(
150150
img_dict = page_data.get_image()
151151
images.append(img_dict['img'])
152152
page_wh_list.append((img_dict['width'], img_dict['height']))
153+
images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
153154

154155
if len(images) >= MIN_BATCH_INFERENCE_SIZE:
155156
batch_size = MIN_BATCH_INFERENCE_SIZE
156-
batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
157+
batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
157158
else:
158-
batch_images = [images]
159+
batch_images = [images_with_extra_info]
159160

160161
results = []
161162
for sn, batch_image in enumerate(batch_images):
@@ -181,7 +182,7 @@ def doc_analyze(
181182

182183
def batch_doc_analyze(
183184
datasets: list[Dataset],
184-
ocr: bool = False,
185+
parse_method: str,
185186
show_log: bool = False,
186187
lang=None,
187188
layout_model=None,
@@ -192,47 +193,31 @@ def batch_doc_analyze(
192193
batch_size = MIN_BATCH_INFERENCE_SIZE
193194
images = []
194195
page_wh_list = []
195-
lang_list = []
196-
lang_s = set()
196+
197+
images_with_extra_info = []
197198
for dataset in datasets:
198199
for index in range(len(dataset)):
199200
if lang is None or lang == 'auto':
200-
lang_list.append(dataset._lang)
201+
_lang = dataset._lang
201202
else:
202-
lang_list.append(lang)
203-
lang_s.add(lang_list[-1])
203+
_lang = lang
204+
204205
page_data = dataset.get_page(index)
205206
img_dict = page_data.get_image()
206207
images.append(img_dict['img'])
207208
page_wh_list.append((img_dict['width'], img_dict['height']))
209+
if parse_method == 'auto':
210+
images_with_extra_info.append((images[-1], dataset.classify() == SupportedPdfParseMethod.OCR, _lang))
211+
else:
212+
images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))
208213

209-
batch_images = []
210-
img_idx_list = []
211-
for t_lang in lang_s:
212-
tmp_img_idx_list = []
213-
for i, _lang in enumerate(lang_list):
214-
if _lang == t_lang:
215-
tmp_img_idx_list.append(i)
216-
img_idx_list.extend(tmp_img_idx_list)
217-
218-
if batch_size >= len(tmp_img_idx_list):
219-
batch_images.append((t_lang, [images[j] for j in tmp_img_idx_list]))
220-
else:
221-
slices = [tmp_img_idx_list[k:k+batch_size] for k in range(0, len(tmp_img_idx_list), batch_size)]
222-
for arr in slices:
223-
batch_images.append((t_lang, [images[j] for j in arr]))
224-
225-
unorder_results = []
226-
227-
for sn, (_lang, batch_image) in enumerate(batch_images):
228-
_, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, _lang, layout_model, formula_enable, table_enable)
229-
unorder_results.extend(result)
230-
results = [None] * len(img_idx_list)
231-
for i, idx in enumerate(img_idx_list):
232-
results[idx] = unorder_results[i]
214+
batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
215+
results = []
216+
for sn, batch_image in enumerate(batch_images):
217+
_, result = may_batch_image_analyze(batch_image, sn, True, show_log, lang, layout_model, formula_enable, table_enable)
218+
results.extend(result)
233219

234220
infer_results = []
235-
236221
from magic_pdf.operators.models import InferenceResult
237222
for index in range(len(datasets)):
238223
dataset = datasets[index]
@@ -248,9 +233,9 @@ def batch_doc_analyze(
248233

249234

250235
def may_batch_image_analyze(
251-
images: list[np.ndarray],
236+
images_with_extra_info: list[(np.ndarray, bool, str)],
252237
idx: int,
253-
ocr: bool = False,
238+
ocr: bool,
254239
show_log: bool = False,
255240
lang=None,
256241
layout_model=None,
@@ -267,6 +252,7 @@ def may_batch_image_analyze(
267252
ocr, show_log, lang, layout_model, formula_enable, table_enable
268253
)
269254

255+
images = [image for image, _, _ in images_with_extra_info]
270256
batch_analyze = False
271257
batch_ratio = 1
272258
device = get_device()
@@ -306,8 +292,8 @@ def may_batch_image_analyze(
306292
images.append(img_dict['img'])
307293
page_wh_list.append((img_dict['width'], img_dict['height']))
308294
"""
309-
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
310-
results = batch_model(images)
295+
batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
296+
results = batch_model(images_with_extra_info)
311297
"""
312298
for index in range(len(dataset)):
313299
if start_page_id <= index <= end_page_id:

magic_pdf/tools/common.py

Lines changed: 3 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -314,21 +314,10 @@ def batch_do_parse(
314314
dss.append(PymuDocDataset(v, lang=lang))
315315
else:
316316
dss.append(v)
317-
dss_with_fn = list(zip(dss, pdf_file_names))
318-
if parse_method == 'auto':
319-
dss_typed_txt = [(i, x) for i, x in enumerate(dss_with_fn) if x[0].classify() == SupportedPdfParseMethod.TXT]
320-
dss_typed_ocr = [(i, x) for i, x in enumerate(dss_with_fn) if x[0].classify() == SupportedPdfParseMethod.OCR]
321-
infer_results = [None] * len(dss_with_fn)
322-
infer_results_txt = batch_doc_analyze([x[1][0] for x in dss_typed_txt], lang=lang, ocr=False, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
323-
infer_results_ocr = batch_doc_analyze([x[1][0] for x in dss_typed_ocr], lang=lang, ocr=True, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
324-
for i, infer_res in enumerate(infer_results_txt):
325-
infer_results[dss_typed_txt[i][0]] = infer_res
326-
for i, infer_res in enumerate(infer_results_ocr):
327-
infer_results[dss_typed_ocr[i][0]] = infer_res
328-
else:
329-
infer_results = batch_doc_analyze(dss, lang=lang, ocr=parse_method == 'ocr', layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
317+
318+
infer_results = batch_doc_analyze(dss, parse_method, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
330319
for idx, infer_result in enumerate(infer_results):
331-
_do_parse(output_dir, dss_with_fn[idx][1], dss_with_fn[idx][0], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang)
320+
_do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang)
332321

333322

334323
parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])

0 commit comments

Comments
 (0)