99import numpy as np
1010
1111from .cal_rec_boxes import CalRecBoxes
12- from .ch_ppocr_cls import TextClassifier
13- from .ch_ppocr_det import TextDetector
14- from .ch_ppocr_rec import TextRecArguments , TextRecognizer , TextRecOutput
12+ from .ch_ppocr_cls import TextClassifier , TextClsOutput
13+ from .ch_ppocr_det import TextDetector , TextDetOutput
14+ from .ch_ppocr_rec import TextRecInput , TextRecognizer , TextRecOutput
1515from .utils import (
1616 LoadImage ,
1717 RapidOCROutput ,
@@ -71,7 +71,7 @@ def __call__(
7171 use_cls : Optional [bool ] = None ,
7272 use_rec : Optional [bool ] = None ,
7373 ** kwargs ,
74- ) -> Tuple [ Optional [ List [ List [ Union [ Any , str ]]]], Optional [ List [ float ]]] :
74+ ) -> RapidOCROutput :
7575 use_det = self .use_det if use_det is None else use_det
7676 use_cls = self .use_cls if use_cls is None else use_cls
7777 use_rec = self .use_rec if use_rec is None else use_rec
@@ -94,30 +94,30 @@ def __call__(
9494 img , ratio_h , ratio_w = self .preprocess (img )
9595 op_record ["preprocess" ] = {"ratio_h" : ratio_h , "ratio_w" : ratio_w }
9696
97- dt_boxes , cls_res , rec_res = None , None , TextRecOutput ()
98- det_elapse , cls_elapse = 0.0 , 0.0
97+ det_res , cls_res , rec_res = TextDetOutput (), TextClsOutput (), TextRecOutput ()
9998
10099 if use_det :
101100 img , op_record = self .maybe_add_letterbox (img , op_record )
102101 det_res = self .text_det (img )
103102 if det_res .boxes is None :
104103 return RapidOCROutput ()
105104
106- img = self .get_crop_img_list (img , dt_boxes )
105+ img = self .get_crop_img_list (img , det_res )
107106
108107 if use_cls :
109- img , cls_res , cls_elapse = self .text_cls (img )
108+ cls_res = self .text_cls (img )
109+ img = cls_res .img_list
110110
111111 if use_rec :
112- rec_input = TextRecArguments (img = img , return_word_box = return_word_box )
112+ rec_input = TextRecInput (img = img , return_word_box = return_word_box )
113113 rec_res = self .text_rec (rec_input )
114114
115115 if (
116116 return_word_box
117- and dt_boxes is not None
117+ and det_res . boxes is not None
118118 and all (v for v in rec_res .word_results )
119119 ):
120- rec_res = self .cal_rec_boxes (img , dt_boxes , rec_res )
120+ rec_res = self .cal_rec_boxes (img , det_res . boxes , rec_res )
121121 origin_words = []
122122 for one_word in rec_res .word_results :
123123 one_word_points = one_word [2 ]
@@ -131,10 +131,12 @@ def __call__(
131131 origin_words .append ((one_word [0 ], one_word [1 ], origin_words_points ))
132132 rec_res .word_results = tuple (origin_words )
133133
134- if dt_boxes is not None :
135- dt_boxes = self ._get_origin_points (dt_boxes , op_record , raw_h , raw_w )
134+ if det_res .boxes is not None :
135+ det_res .boxes = self ._get_origin_points (
136+ det_res .boxes , op_record , raw_h , raw_w
137+ )
136138
137- ocr_res = self .get_final_res (dt_boxes , cls_res , rec_res , det_elapse , cls_elapse )
139+ ocr_res = self .get_final_res (det_res , cls_res , rec_res )
138140 return ocr_res
139141
140142 def preprocess (self , img : np .ndarray ) -> Tuple [np .ndarray , float , float ]:
@@ -174,18 +176,8 @@ def _get_padding_h(self, h: int, w: int) -> int:
174176 padding_h = int (abs (new_h - h ) / 2 )
175177 return padding_h
176178
177- def auto_text_det (
178- self , img : np .ndarray
179- ) -> Tuple [Optional [List [np .ndarray ]], float ]:
180- dt_boxes , det_elapse = self .text_det (img )
181- if dt_boxes is None or len (dt_boxes ) < 1 :
182- return None , 0.0
183-
184- dt_boxes = self .sorted_boxes (dt_boxes )
185- return dt_boxes , det_elapse
186-
187179 def get_crop_img_list (
188- self , img : np .ndarray , dt_boxes : List [ np . ndarray ]
180+ self , img : np .ndarray , det_res : TextDetOutput
189181 ) -> List [np .ndarray ]:
190182 def get_rotate_crop_image (img : np .ndarray , points : np .ndarray ) -> np .ndarray :
191183 img_crop_width = int (
@@ -222,38 +214,12 @@ def get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray:
222214 return dst_img
223215
224216 img_crop_list = []
225- for box in dt_boxes :
217+ for box in det_res . boxes :
226218 tmp_box = copy .deepcopy (box )
227219 img_crop = get_rotate_crop_image (img , tmp_box )
228220 img_crop_list .append (img_crop )
229221 return img_crop_list
230222
231- @staticmethod
232- def sorted_boxes (dt_boxes : np .ndarray ) -> List [np .ndarray ]:
233- """
234- Sort text boxes in order from top to bottom, left to right
235- args:
236- dt_boxes(array):detected text boxes with shape [4, 2]
237- return:
238- sorted boxes(array) with shape [4, 2]
239- """
240- num_boxes = dt_boxes .shape [0 ]
241- sorted_boxes = sorted (dt_boxes , key = lambda x : (x [0 ][1 ], x [0 ][0 ]))
242- _boxes = list (sorted_boxes )
243-
244- for i in range (num_boxes - 1 ):
245- for j in range (i , - 1 , - 1 ):
246- if (
247- abs (_boxes [j + 1 ][0 ][1 ] - _boxes [j ][0 ][1 ]) < 10
248- and _boxes [j + 1 ][0 ][0 ] < _boxes [j ][0 ][0 ]
249- ):
250- tmp = _boxes [j ]
251- _boxes [j ] = _boxes [j + 1 ]
252- _boxes [j + 1 ] = tmp
253- else :
254- break
255- return _boxes
256-
257223 def _get_origin_points (
258224 self ,
259225 dt_boxes : List [np .ndarray ],
@@ -284,28 +250,33 @@ def _get_origin_points(
284250 return dt_boxes_array
285251
286252 def get_final_res (
287- self ,
288- dt_boxes : Optional [List [np .ndarray ]],
289- cls_res : Optional [List [List [Union [str , float ]]]],
290- rec_res : TextRecOutput ,
291- det_elapse : float ,
292- cls_elapse : float ,
293- ) -> Tuple [Optional [List [List [Union [Any , str ]]]], Optional [List [float ]]]:
294- if dt_boxes is None and rec_res is None and cls_res is not None :
295- return cls_res , [cls_elapse ]
253+ self , det_res : TextDetOutput , cls_res : TextClsOutput , rec_res : TextRecOutput
254+ ) -> Union [TextDetOutput , TextClsOutput , TextRecOutput , RapidOCROutput ]:
255+ dt_boxes = det_res .boxes
256+ txt_res = rec_res .line_txts
296257
297- if dt_boxes is None and rec_res is None :
298- return None , None
258+ if dt_boxes is None and txt_res is None and cls_res . cls_res is not None :
259+ return cls_res
299260
300- if dt_boxes is None and rec_res is not None :
301- return [[ v [ 0 ], v [ 1 ]] for v in rec_res . line_results ], [ rec_res . elapse ]
261+ if dt_boxes is None and txt_res is None :
262+ return RapidOCROutput ()
302263
303- if dt_boxes is not None and rec_res is None :
304- return [ box . tolist () for box in dt_boxes ], [ det_elapse ]
264+ if dt_boxes is None and txt_res is not None :
265+ return rec_res
305266
306- dt_boxes , rec_res = self .filter_result (dt_boxes , rec_res )
307- if not dt_boxes or not rec_res or len (dt_boxes ) <= 0 :
308- return None , None
267+ if dt_boxes is not None and rec_res is None :
268+ return det_res
269+
270+ ocr_res = RapidOCROutput (
271+ boxes = det_res .boxes ,
272+ txts = rec_res .line_txts ,
273+ scores = rec_res .line_scores ,
274+ word_results = rec_res .word_results ,
275+ elapse_list = [det_res .elapse , cls_res .elapse , rec_res .elapse ],
276+ )
277+ ocr_res = self .filter_by_text_score (ocr_res )
278+ if len (ocr_res .boxes ) <= 0 :
279+ return RapidOCROutput ()
309280
310281 ocr_res = [
311282 [box .tolist (), * res ] for box , res in zip (dt_boxes , rec_res .line_results )
@@ -316,23 +287,18 @@ def get_final_res(
316287 ]
317288 return ocr_res
318289
319- def filter_result (
320- self ,
321- dt_boxes : Optional [List [np .ndarray ]],
322- rec_res : Optional [List [Tuple [str , float ]]],
323- ) -> Tuple [Optional [List [np .ndarray ]], TextRecOutput ]:
324- if dt_boxes is None or rec_res is None :
325- return None , None
326-
327- filter_boxes , filter_rec_res = [], []
328- for box , rec_reuslt in zip (dt_boxes , rec_res .line_results ):
329- text , score = rec_reuslt [0 ], rec_reuslt [1 ]
290+ def filter_by_text_score (self , ocr_res : RapidOCROutput ) -> RapidOCROutput :
291+ filter_boxes , filter_txts , filter_scores = [], [], []
292+ for box , txt , score in zip (ocr_res .boxes , ocr_res .txts , ocr_res .scores ):
330293 if float (score ) >= self .text_score :
331294 filter_boxes .append (box )
332- filter_rec_res .append (rec_reuslt )
295+ filter_txts .append (txt )
296+ filter_scores .append (score )
333297
334- rec_res .line_results = filter_rec_res
335- return filter_boxes , rec_res
298+ ocr_res .boxes = filter_boxes
299+ ocr_res .txts = filter_txts
300+ ocr_res .scores = filter_scores
301+ return ocr_res
336302
337303
338304def main ():
0 commit comments