55import time
66from dataclasses import asdict
77from pathlib import Path
8- from typing import List , Optional , Tuple , Union
8+ from typing import Any , Dict , List , Optional , Tuple , Union
99
1010import numpy as np
1111
1515 LoadImage ,
1616 Logger ,
1717 ModelType ,
18- EngineType ,
1918 RapidTableInput ,
2019 RapidTableOutput ,
2120 get_boxes_recs ,
@@ -39,32 +38,21 @@ def __init__(self, cfg: Optional[RapidTableInput] = None):
3938
4039 self .ocr_engine = None
4140 if cfg .use_ocr :
42- self .ocr_engine = self ._init_ocr_engine ()
41+ self .ocr_engine = self ._init_ocr_engine (self . cfg . ocr_params )
4342
4443 self .table_matcher = TableMatch ()
4544 self .load_img = LoadImage ()
4645
47- def _init_ocr_engine (self ):
48- try :
49- rapidocr_ = import_package ("rapidocr" )
50- if rapidocr_ is None :
51- raise ModuleNotFoundError ("rapidocr package is not installed" )
52- if self .cfg .engine_type == EngineType .TORCH :
53- EngineType_RapidOCR = rapidocr_ .EngineType
54- return rapidocr_ .RapidOCR (
55- params = {
56- "Det.engine_type" : EngineType_RapidOCR .TORCH ,
57- "Cls.engine_type" : EngineType_RapidOCR .TORCH ,
58- "Rec.engine_type" : EngineType_RapidOCR .TORCH ,
59- "EngineConfig.torch.use_cuda" : True , # Use torch GPU to infer
60- "EngineConfig.torch.gpu_id" : 0 # Specify GPU id
61- }
62- )
63- return rapidocr_ .RapidOCR ()
64- except ModuleNotFoundError :
46+ def _init_ocr_engine (self , params : Dict [Any , Any ]):
47+ rapidocr_ = import_package ("rapidocr" )
48+ if rapidocr_ is None :
6549 logger .warning ("rapidocr package is not installed, only table rec" )
6650 return None
6751
52+ if not params :
53+ return rapidocr_ .RapidOCR ()
54+ return rapidocr_ .RapidOCR (params = params )
55+
6856 def _init_table_structer (self ):
6957 if self .cfg .model_type == ModelType .UNITABLE :
7058 from .table_structure .unitable import UniTableStructure
0 commit comments