1717 ModelType ,
1818 RapidTableInput ,
1919 RapidTableOutput ,
20+ get_boxes_recs ,
2021 import_package ,
2122)
2223
@@ -34,7 +35,11 @@ def __init__(self, cfg: Optional[RapidTableInput] = None):
3435
3536 self .cfg = cfg
3637 self .table_structure = self ._init_table_structer ()
37- self .ocr_engine = self ._init_ocr_engine ()
38+
39+ self .ocr_engine = None
40+ if cfg .use_ocr :
41+ self .ocr_engine = self ._init_ocr_engine ()
42+
3843 self .table_matcher = TableMatch ()
3944 self .load_img = LoadImage ()
4045
@@ -58,72 +63,48 @@ def _init_table_structer(self):
5863 def __call__ (
5964 self ,
6065 img_content : Union [str , np .ndarray , bytes , Path ],
61- ocr_result : Optional [List [ Union [ List [ List [ float ]], str , str ]]] = None ,
66+ ocr_results : Optional [Tuple [ np . ndarray , Tuple [ str ], Tuple [ float ]]] = None ,
6267 ) -> RapidTableOutput :
63- if self .ocr_engine is None and ocr_result is None :
64- raise ValueError (
65- "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
66- )
68+ s = time .perf_counter ()
6769
6870 img = self .load_img (img_content )
6971
70- s = time .perf_counter ()
71- h , w = img .shape [:2 ]
72-
73- if ocr_result is None :
74- ocr_result = self .ocr_engine (img )
75- ocr_result = list (
76- zip (
77- ocr_result .boxes ,
78- ocr_result .txts ,
79- ocr_result .scores ,
80- )
81- )
82- dt_boxes , rec_res = self .get_boxes_recs (ocr_result , h , w )
72+ dt_boxes , rec_res = self .get_ocr_results (img , ocr_results )
73+ pred_structures , cell_bboxes , logic_points = self .get_table_rec_results (img )
74+ pred_html = self .get_table_matcher (
75+ pred_structures , cell_bboxes , dt_boxes , rec_res
76+ )
8377
84- pred_structures , cell_bboxes , _ = self .table_structure (img )
78+ elapse = time .perf_counter () - s
79+ return RapidTableOutput (img , pred_html , cell_bboxes , logic_points , elapse )
8580
86- # 适配slanet-plus模型输出的box缩放还原
87- if self .cfg .model_type == ModelType .SLANETPLUS :
88- cell_bboxes = self .adapt_slanet_plus (img , cell_bboxes )
81+ def get_ocr_results (
82+ self , img : np .ndarray , ocr_results : Tuple [np .ndarray , Tuple [str ], Tuple [float ]]
83+ ) -> Tuple [Optional [np .ndarray ], Optional [np .ndarray ]]:
84+ if ocr_results is not None :
85+ return get_boxes_recs (ocr_results , img .shape [:2 ])
8986
90- pred_html = self .table_matcher (pred_structures , cell_bboxes , dt_boxes , rec_res )
87+ if not self .cfg .use_ocr :
88+ return None , None
9189
92- # 过滤掉占位的bbox
93- mask = ~ np .all (cell_bboxes == 0 , axis = 1 )
94- cell_bboxes = cell_bboxes [mask ]
90+ ori_ocr_res = self .ocr_engine (img )
91+ if ori_ocr_res .boxes is None :
92+ logger .warning ("OCR Result is empty" )
93+ return None , None
9594
95+ ocr_results = [ori_ocr_res .boxes , ori_ocr_res .txts , ori_ocr_res .scores ]
96+ return get_boxes_recs (ocr_results , img .shape [:2 ])
97+
98+ def get_table_rec_results (self , img : np .ndarray ):
99+ pred_structures , cell_bboxes , _ = self .table_structure (img )
96100 logic_points = self .table_matcher .decode_logic_points (pred_structures )
97- elapse = time .perf_counter () - s
98- return RapidTableOutput (img , pred_html , cell_bboxes , logic_points , elapse )
101+ return pred_structures , cell_bboxes , logic_points
99102
100- def get_boxes_recs (
101- self , ocr_result : List [Union [List [List [float ]], str , str ]], h : int , w : int
102- ) -> Tuple [np .ndarray , Tuple [str , str ]]:
103- dt_boxes , rec_res , scores = list (zip (* ocr_result ))
104- rec_res = list (zip (rec_res , scores ))
105-
106- r_boxes = []
107- for box in dt_boxes :
108- box = np .array (box )
109- x_min = max (0 , box [:, 0 ].min () - 1 )
110- x_max = min (w , box [:, 0 ].max () + 1 )
111- y_min = max (0 , box [:, 1 ].min () - 1 )
112- y_max = min (h , box [:, 1 ].max () + 1 )
113- box = [x_min , y_min , x_max , y_max ]
114- r_boxes .append (box )
115- dt_boxes = np .array (r_boxes )
116- return dt_boxes , rec_res
117-
118- def adapt_slanet_plus (self , img : np .ndarray , cell_bboxes : np .ndarray ) -> np .ndarray :
119- h , w = img .shape [:2 ]
120- resized = 488
121- ratio = min (resized / h , resized / w )
122- w_ratio = resized / (w * ratio )
123- h_ratio = resized / (h * ratio )
124- cell_bboxes [:, 0 ::2 ] *= w_ratio
125- cell_bboxes [:, 1 ::2 ] *= h_ratio
126- return cell_bboxes
103+ def get_table_matcher (self , pred_structures , cell_bboxes , dt_boxes , rec_res ):
104+ if dt_boxes is None and rec_res is None :
105+ return None
106+
107+ return self .table_matcher (pred_structures , cell_bboxes , dt_boxes , rec_res )
127108
128109
129110def parse_args (arg_list : Optional [List [str ]] = None ):
@@ -158,11 +139,9 @@ def main(arg_list: Optional[List[str]] = None):
158139 if table_engine .ocr_engine is None :
159140 raise ValueError ("ocr engine is None" )
160141
161- rapid_ocr_output = table_engine .ocr_engine (img_path )
162- ocr_result = list (
163- zip (rapid_ocr_output .boxes , rapid_ocr_output .txts , rapid_ocr_output .scores )
164- )
165- table_results = table_engine (img_path , ocr_result )
142+ ori_ocr_res = table_engine .ocr_engine (img_path )
143+ ocr_results = [ori_ocr_res .boxes , ori_ocr_res .txts , ori_ocr_res .scores ]
144+ table_results = table_engine (img_path , ocr_results = ocr_results )
166145 print (table_results .pred_html )
167146
168147 if args .vis :
0 commit comments