1515from loguru import logger
1616
1717from magic_pdf .model .sub_modules .model_utils import get_vram
18-
18+ from magic_pdf . config . enums import SupportedPdfParseMethod
1919import magic_pdf .model as model_config
2020from magic_pdf .data .dataset import Dataset
2121from magic_pdf .libs .clean_memory import clean_memory
@@ -150,12 +150,13 @@ def doc_analyze(
150150 img_dict = page_data .get_image ()
151151 images .append (img_dict ['img' ])
152152 page_wh_list .append ((img_dict ['width' ], img_dict ['height' ]))
153+ images_with_extra_info = [(images [index ], ocr , dataset ._lang ) for index in range (len (dataset ))]
153154
154155 if len (images ) >= MIN_BATCH_INFERENCE_SIZE :
155156 batch_size = MIN_BATCH_INFERENCE_SIZE
156- batch_images = [images [i :i + batch_size ] for i in range (0 , len (images ), batch_size )]
157+ batch_images = [images_with_extra_info [i :i + batch_size ] for i in range (0 , len (images_with_extra_info ), batch_size )]
157158 else :
158- batch_images = [images ]
159+ batch_images = [images_with_extra_info ]
159160
160161 results = []
161162 for sn , batch_image in enumerate (batch_images ):
@@ -181,7 +182,7 @@ def doc_analyze(
181182
182183def batch_doc_analyze (
183184 datasets : list [Dataset ],
184- ocr : bool = False ,
185+ parse_method : str ,
185186 show_log : bool = False ,
186187 lang = None ,
187188 layout_model = None ,
@@ -192,47 +193,31 @@ def batch_doc_analyze(
192193 batch_size = MIN_BATCH_INFERENCE_SIZE
193194 images = []
194195 page_wh_list = []
195- lang_list = []
196- lang_s = set ()
196+
197+ images_with_extra_info = []
197198 for dataset in datasets :
198199 for index in range (len (dataset )):
199200 if lang is None or lang == 'auto' :
200- lang_list . append ( dataset ._lang )
201+ _lang = dataset ._lang
201202 else :
202- lang_list . append ( lang )
203- lang_s . add ( lang_list [ - 1 ])
203+ _lang = lang
204+
204205 page_data = dataset .get_page (index )
205206 img_dict = page_data .get_image ()
206207 images .append (img_dict ['img' ])
207208 page_wh_list .append ((img_dict ['width' ], img_dict ['height' ]))
209+ if parse_method == 'auto' :
210+ images_with_extra_info .append ((images [- 1 ], dataset .classify () == SupportedPdfParseMethod .OCR , _lang ))
211+ else :
212+ images_with_extra_info .append ((images [- 1 ], parse_method == 'ocr' , _lang ))
208213
209- batch_images = []
210- img_idx_list = []
211- for t_lang in lang_s :
212- tmp_img_idx_list = []
213- for i , _lang in enumerate (lang_list ):
214- if _lang == t_lang :
215- tmp_img_idx_list .append (i )
216- img_idx_list .extend (tmp_img_idx_list )
217-
218- if batch_size >= len (tmp_img_idx_list ):
219- batch_images .append ((t_lang , [images [j ] for j in tmp_img_idx_list ]))
220- else :
221- slices = [tmp_img_idx_list [k :k + batch_size ] for k in range (0 , len (tmp_img_idx_list ), batch_size )]
222- for arr in slices :
223- batch_images .append ((t_lang , [images [j ] for j in arr ]))
224-
225- unorder_results = []
226-
227- for sn , (_lang , batch_image ) in enumerate (batch_images ):
228- _ , result = may_batch_image_analyze (batch_image , sn , ocr , show_log , _lang , layout_model , formula_enable , table_enable )
229- unorder_results .extend (result )
230- results = [None ] * len (img_idx_list )
231- for i , idx in enumerate (img_idx_list ):
232- results [idx ] = unorder_results [i ]
214+ batch_images = [images_with_extra_info [i :i + batch_size ] for i in range (0 , len (images_with_extra_info ), batch_size )]
215+ results = []
216+ for sn , batch_image in enumerate (batch_images ):
217+ _ , result = may_batch_image_analyze (batch_image , sn , True , show_log , lang , layout_model , formula_enable , table_enable )
218+ results .extend (result )
233219
234220 infer_results = []
235-
236221 from magic_pdf .operators .models import InferenceResult
237222 for index in range (len (datasets )):
238223 dataset = datasets [index ]
@@ -248,9 +233,9 @@ def batch_doc_analyze(
248233
249234
250235def may_batch_image_analyze (
251- images : list [np .ndarray ],
236+ images_with_extra_info : list [( np .ndarray , bool , str ) ],
252237 idx : int ,
253- ocr : bool = False ,
238+ ocr : bool ,
254239 show_log : bool = False ,
255240 lang = None ,
256241 layout_model = None ,
@@ -267,6 +252,7 @@ def may_batch_image_analyze(
267252 ocr , show_log , lang , layout_model , formula_enable , table_enable
268253 )
269254
255+ images = [image for image , _ , _ in images_with_extra_info ]
270256 batch_analyze = False
271257 batch_ratio = 1
272258 device = get_device ()
@@ -306,8 +292,8 @@ def may_batch_image_analyze(
306292 images.append(img_dict['img'])
307293 page_wh_list.append((img_dict['width'], img_dict['height']))
308294 """
309- batch_model = BatchAnalyze (model = custom_model , batch_ratio = batch_ratio )
310- results = batch_model (images )
295+ batch_model = BatchAnalyze (model_manager , batch_ratio , show_log , layout_model , formula_enable , table_enable )
296+ results = batch_model (images_with_extra_info )
311297 """
312298 for index in range(len(dataset)):
313299 if start_page_id <= index <= end_page_id:
0 commit comments