Skip to content

Commit b008fba

Browse files
committed
chore: update files
1 parent a86fb73 commit b008fba

File tree

10 files changed

+93
-103
lines changed

10 files changed

+93
-103
lines changed

python/demo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
from rapidocr import RapidOCR, VisRes
77

8+
# from rapidocr_onnxruntime import RapidOCR, VisRes
9+
10+
811
# from rapidocr_paddle import RapidOCR, VisRes
912
# from rapidocr_openvino import RapidOCR, VisRes
1013

python/rapidocr/cal_rec_boxes/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __call__(
2828
for idx, (img, box) in enumerate(zip(imgs, dt_boxes)):
2929
direction = self.get_box_direction(box)
3030

31-
rec_txt, rec_conf = rec_res.line_results[idx]
31+
rec_txt = rec_res.line_txts[idx]
3232
rec_word_info = rec_res.word_results[idx]
3333

3434
h, w = img.shape[:2]
python/rapidocr/ch_ppocr_cls/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: liekkaskono@163.com
4-
from .text_cls import TextClassifier
4+
from .main import TextClassifier
5+
from .utils import TextClsOutput
python/rapidocr/ch_ppocr_cls/main.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
import copy
1616
import math
1717
import time
18-
from typing import Any, Dict, List, Tuple, Union
18+
from typing import Any, Dict, List, Union
1919

2020
import cv2
2121
import numpy as np
22+
2223
from rapidocr.utils import OrtInferSession, read_yaml
2324

24-
from .utils import ClsPostProcess
25+
from .utils import ClsPostProcess, TextClsOutput
2526

2627

2728
class TextClassifier:
@@ -33,9 +34,7 @@ def __init__(self, config: Dict[str, Any]):
3334

3435
self.infer = OrtInferSession(config)
3536

36-
def __call__(
37-
self, img_list: Union[np.ndarray, List[np.ndarray]]
38-
) -> Tuple[List[np.ndarray], List[List[Union[str, float]]], float]:
37+
def __call__(self, img_list: Union[np.ndarray, List[np.ndarray]]) -> TextClsOutput:
3938
if isinstance(img_list, np.ndarray):
4039
img_list = [img_list]
4140

@@ -48,7 +47,7 @@ def __call__(
4847
indices = np.argsort(np.array(width_list))
4948

5049
img_num = len(img_list)
51-
cls_res = [["", 0.0]] * img_num
50+
cls_res = [("", 0.0)] * img_num
5251
batch_num = self.cls_batch_num
5352
elapse = 0
5453
for beg_img_no in range(0, img_num, batch_num):
@@ -67,12 +66,12 @@ def __call__(
6766
elapse += time.time() - starttime
6867

6968
for rno, (label, score) in enumerate(cls_result):
70-
cls_res[indices[beg_img_no + rno]] = [label, score]
69+
cls_res[indices[beg_img_no + rno]] = (label, score)
7170
if "180" in label and score > self.cls_thresh:
7271
img_list[indices[beg_img_no + rno]] = cv2.rotate(
7372
img_list[indices[beg_img_no + rno]], 1
7473
)
75-
return img_list, cls_res, elapse
74+
return TextClsOutput(img_list=img_list, cls_res=cls_res, elapse=elapse)
7675

7776
def resize_norm_img(self, img: np.ndarray) -> np.ndarray:
7877
img_c, img_h, img_w = self.cls_image_shape

python/rapidocr/ch_ppocr_cls/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,19 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import List, Tuple
14+
from dataclasses import dataclass
15+
from typing import List, Optional, Tuple
1516

1617
import numpy as np
1718

1819

20+
@dataclass
21+
class TextClsOutput:
22+
img_list: Optional[List[np.ndarray]] = None
23+
cls_res: Optional[List[Tuple[str, float]]] = None
24+
elapse: Optional[float] = None
25+
26+
1927
class ClsPostProcess:
2028
def __init__(self, label_list: List[str]):
2129
self.label_list = label_list

python/rapidocr/ch_ppocr_rec/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# @Author: SWHL
33
# @Contact: liekkaskono@163.com
44
from .main import TextRecognizer
5-
from .utils import TextRecArguments, TextRecOutput
5+
from .utils import TextRecInput, TextRecOutput

python/rapidocr/ch_ppocr_rec/main.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from rapidocr.utils import OrtInferSession
2222

23-
from .utils import CTCLabelDecode, TextRecArguments, TextRecOutput
23+
from .utils import CTCLabelDecode, TextRecInput, TextRecOutput
2424

2525

2626
class TextRecognizer:
@@ -39,7 +39,7 @@ def __init__(self, config: Dict[str, Any]):
3939
self.rec_batch_num = config["rec_batch_num"]
4040
self.rec_image_shape = config["rec_img_shape"]
4141

42-
def __call__(self, args: TextRecArguments) -> TextRecOutput:
42+
def __call__(self, args: TextRecInput) -> TextRecOutput:
4343
img_list = [args.img] if isinstance(args.img, np.ndarray) else args.img
4444
return_word_box = args.return_word_box
4545

@@ -90,7 +90,8 @@ def __call__(self, args: TextRecArguments) -> TextRecOutput:
9090
elapse += time.perf_counter() - start_time
9191

9292
all_line_results, all_word_results = list(zip(*rec_res))
93-
return TextRecOutput(all_line_results, all_word_results, elapse)
93+
line_txts, line_scores = list(zip(*all_line_results))
94+
return TextRecOutput(line_txts, line_scores, all_word_results, elapse)
9495

9596
def resize_norm_img(self, img: np.ndarray, max_wh_ratio: float) -> np.ndarray:
9697
img_channel, img_height, img_width = self.rec_image_shape

python/rapidocr/ch_ppocr_rec/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@ class TextRecConfig:
2222

2323

2424
@dataclass
25-
class TextRecArguments:
25+
class TextRecInput:
2626
img: Union[np.ndarray, List[np.ndarray], None] = None
2727
return_word_box: bool = False
2828

2929

3030
@dataclass
3131
class TextRecOutput:
32-
line_results: Tuple[Tuple[str, float]] = (("", 1.0),)
32+
line_txts: Optional[Tuple[str]] = None
33+
line_scores: Tuple[float] = (1.0,)
3334
word_results: Tuple[Tuple[str, float, Optional[List[List[int]]]]] = (
3435
("", 1.0, None),
3536
)

python/rapidocr/main.py

Lines changed: 50 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import numpy as np
1010

1111
from .cal_rec_boxes import CalRecBoxes
12-
from .ch_ppocr_cls import TextClassifier
13-
from .ch_ppocr_det import TextDetector
14-
from .ch_ppocr_rec import TextRecArguments, TextRecognizer, TextRecOutput
12+
from .ch_ppocr_cls import TextClassifier, TextClsOutput
13+
from .ch_ppocr_det import TextDetector, TextDetOutput
14+
from .ch_ppocr_rec import TextRecInput, TextRecognizer, TextRecOutput
1515
from .utils import (
1616
LoadImage,
1717
RapidOCROutput,
@@ -71,7 +71,7 @@ def __call__(
7171
use_cls: Optional[bool] = None,
7272
use_rec: Optional[bool] = None,
7373
**kwargs,
74-
) -> Tuple[Optional[List[List[Union[Any, str]]]], Optional[List[float]]]:
74+
) -> RapidOCROutput:
7575
use_det = self.use_det if use_det is None else use_det
7676
use_cls = self.use_cls if use_cls is None else use_cls
7777
use_rec = self.use_rec if use_rec is None else use_rec
@@ -94,30 +94,30 @@ def __call__(
9494
img, ratio_h, ratio_w = self.preprocess(img)
9595
op_record["preprocess"] = {"ratio_h": ratio_h, "ratio_w": ratio_w}
9696

97-
dt_boxes, cls_res, rec_res = None, None, TextRecOutput()
98-
det_elapse, cls_elapse = 0.0, 0.0
97+
det_res, cls_res, rec_res = TextDetOutput(), TextClsOutput(), TextRecOutput()
9998

10099
if use_det:
101100
img, op_record = self.maybe_add_letterbox(img, op_record)
102101
det_res = self.text_det(img)
103102
if det_res.boxes is None:
104103
return RapidOCROutput()
105104

106-
img = self.get_crop_img_list(img, dt_boxes)
105+
img = self.get_crop_img_list(img, det_res)
107106

108107
if use_cls:
109-
img, cls_res, cls_elapse = self.text_cls(img)
108+
cls_res = self.text_cls(img)
109+
img = cls_res.img_list
110110

111111
if use_rec:
112-
rec_input = TextRecArguments(img=img, return_word_box=return_word_box)
112+
rec_input = TextRecInput(img=img, return_word_box=return_word_box)
113113
rec_res = self.text_rec(rec_input)
114114

115115
if (
116116
return_word_box
117-
and dt_boxes is not None
117+
and det_res.boxes is not None
118118
and all(v for v in rec_res.word_results)
119119
):
120-
rec_res = self.cal_rec_boxes(img, dt_boxes, rec_res)
120+
rec_res = self.cal_rec_boxes(img, det_res.boxes, rec_res)
121121
origin_words = []
122122
for one_word in rec_res.word_results:
123123
one_word_points = one_word[2]
@@ -131,10 +131,12 @@ def __call__(
131131
origin_words.append((one_word[0], one_word[1], origin_words_points))
132132
rec_res.word_results = tuple(origin_words)
133133

134-
if dt_boxes is not None:
135-
dt_boxes = self._get_origin_points(dt_boxes, op_record, raw_h, raw_w)
134+
if det_res.boxes is not None:
135+
det_res.boxes = self._get_origin_points(
136+
det_res.boxes, op_record, raw_h, raw_w
137+
)
136138

137-
ocr_res = self.get_final_res(dt_boxes, cls_res, rec_res, det_elapse, cls_elapse)
139+
ocr_res = self.get_final_res(det_res, cls_res, rec_res)
138140
return ocr_res
139141

140142
def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, float, float]:
@@ -174,18 +176,8 @@ def _get_padding_h(self, h: int, w: int) -> int:
174176
padding_h = int(abs(new_h - h) / 2)
175177
return padding_h
176178

177-
def auto_text_det(
178-
self, img: np.ndarray
179-
) -> Tuple[Optional[List[np.ndarray]], float]:
180-
dt_boxes, det_elapse = self.text_det(img)
181-
if dt_boxes is None or len(dt_boxes) < 1:
182-
return None, 0.0
183-
184-
dt_boxes = self.sorted_boxes(dt_boxes)
185-
return dt_boxes, det_elapse
186-
187179
def get_crop_img_list(
188-
self, img: np.ndarray, dt_boxes: List[np.ndarray]
180+
self, img: np.ndarray, det_res: TextDetOutput
189181
) -> List[np.ndarray]:
190182
def get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray:
191183
img_crop_width = int(
@@ -222,38 +214,12 @@ def get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray:
222214
return dst_img
223215

224216
img_crop_list = []
225-
for box in dt_boxes:
217+
for box in det_res.boxes:
226218
tmp_box = copy.deepcopy(box)
227219
img_crop = get_rotate_crop_image(img, tmp_box)
228220
img_crop_list.append(img_crop)
229221
return img_crop_list
230222

231-
@staticmethod
232-
def sorted_boxes(dt_boxes: np.ndarray) -> List[np.ndarray]:
233-
"""
234-
Sort text boxes in order from top to bottom, left to right
235-
args:
236-
dt_boxes(array):detected text boxes with shape [4, 2]
237-
return:
238-
sorted boxes(array) with shape [4, 2]
239-
"""
240-
num_boxes = dt_boxes.shape[0]
241-
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
242-
_boxes = list(sorted_boxes)
243-
244-
for i in range(num_boxes - 1):
245-
for j in range(i, -1, -1):
246-
if (
247-
abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
248-
and _boxes[j + 1][0][0] < _boxes[j][0][0]
249-
):
250-
tmp = _boxes[j]
251-
_boxes[j] = _boxes[j + 1]
252-
_boxes[j + 1] = tmp
253-
else:
254-
break
255-
return _boxes
256-
257223
def _get_origin_points(
258224
self,
259225
dt_boxes: List[np.ndarray],
@@ -284,28 +250,33 @@ def _get_origin_points(
284250
return dt_boxes_array
285251

286252
def get_final_res(
287-
self,
288-
dt_boxes: Optional[List[np.ndarray]],
289-
cls_res: Optional[List[List[Union[str, float]]]],
290-
rec_res: TextRecOutput,
291-
det_elapse: float,
292-
cls_elapse: float,
293-
) -> Tuple[Optional[List[List[Union[Any, str]]]], Optional[List[float]]]:
294-
if dt_boxes is None and rec_res is None and cls_res is not None:
295-
return cls_res, [cls_elapse]
253+
self, det_res: TextDetOutput, cls_res: TextClsOutput, rec_res: TextRecOutput
254+
) -> Union[TextDetOutput, TextClsOutput, TextRecOutput, RapidOCROutput]:
255+
dt_boxes = det_res.boxes
256+
txt_res = rec_res.line_txts
296257

297-
if dt_boxes is None and rec_res is None:
298-
return None, None
258+
if dt_boxes is None and txt_res is None and cls_res.cls_res is not None:
259+
return cls_res
299260

300-
if dt_boxes is None and rec_res is not None:
301-
return [[v[0], v[1]] for v in rec_res.line_results], [rec_res.elapse]
261+
if dt_boxes is None and txt_res is None:
262+
return RapidOCROutput()
302263

303-
if dt_boxes is not None and rec_res is None:
304-
return [box.tolist() for box in dt_boxes], [det_elapse]
264+
if dt_boxes is None and txt_res is not None:
265+
return rec_res
305266

306-
dt_boxes, rec_res = self.filter_result(dt_boxes, rec_res)
307-
if not dt_boxes or not rec_res or len(dt_boxes) <= 0:
308-
return None, None
267+
if dt_boxes is not None and rec_res is None:
268+
return det_res
269+
270+
ocr_res = RapidOCROutput(
271+
boxes=det_res.boxes,
272+
txts=rec_res.line_txts,
273+
scores=rec_res.line_scores,
274+
word_results=rec_res.word_results,
275+
elapse_list=[det_res.elapse, cls_res.elapse, rec_res.elapse],
276+
)
277+
ocr_res = self.filter_by_text_score(ocr_res)
278+
if len(ocr_res.boxes) <= 0:
279+
return RapidOCROutput()
309280

310281
ocr_res = [
311282
[box.tolist(), *res] for box, res in zip(dt_boxes, rec_res.line_results)
@@ -316,23 +287,18 @@ def get_final_res(
316287
]
317288
return ocr_res
318289

319-
def filter_result(
320-
self,
321-
dt_boxes: Optional[List[np.ndarray]],
322-
rec_res: Optional[List[Tuple[str, float]]],
323-
) -> Tuple[Optional[List[np.ndarray]], TextRecOutput]:
324-
if dt_boxes is None or rec_res is None:
325-
return None, None
326-
327-
filter_boxes, filter_rec_res = [], []
328-
for box, rec_reuslt in zip(dt_boxes, rec_res.line_results):
329-
text, score = rec_reuslt[0], rec_reuslt[1]
290+
def filter_by_text_score(self, ocr_res: RapidOCROutput) -> RapidOCROutput:
291+
filter_boxes, filter_txts, filter_scores = [], [], []
292+
for box, txt, score in zip(ocr_res.boxes, ocr_res.txts, ocr_res.scores):
330293
if float(score) >= self.text_score:
331294
filter_boxes.append(box)
332-
filter_rec_res.append(rec_reuslt)
295+
filter_txts.append(txt)
296+
filter_scores.append(score)
333297

334-
rec_res.line_results = filter_rec_res
335-
return filter_boxes, rec_res
298+
ocr_res.boxes = filter_boxes
299+
ocr_res.txts = filter_txts
300+
ocr_res.scores = filter_scores
301+
return ocr_res
336302

337303

338304
def main():

python/rapidocr/utils/typings.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: liekkaskono@163.com
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
5+
from typing import List, Optional, Tuple
56

67

78
@dataclass
89
class RapidOCROutput:
9-
pass
10+
boxes: Optional[List[List[float]]] = None
11+
txts: Optional[List[str]] = None
12+
scores: Optional[List[float]] = None
13+
word_results: Tuple[Tuple[str, float, Optional[List[List[int]]]]] = (
14+
("", 1.0, None),
15+
)
16+
elapse_list: List[float] = field(default_factory=list)
17+
elapse: float = field(init=False)
18+
19+
def __post_init__(self):
20+
self.elapse = sum(self.elapse_list)

0 commit comments

Comments
 (0)