Skip to content

Commit b7bd220

Browse files
committed
fix(rapidocr): fixed issue #498
1 parent 57c67c1 commit b7bd220

File tree

3 files changed

+35
-38
lines changed

3 files changed

+35
-38
lines changed

python/rapidocr/main.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ class RapidOCR:
3535
def __init__(
3636
self, config_path: Optional[str] = None, params: Optional[Dict[str, Any]] = None
3737
):
38-
cfg = self.load_config(config_path, params)
39-
self.initialize(cfg)
38+
cfg = self._load_config(config_path, params)
39+
self._initialize(cfg)
4040

4141
self.logger = Logger(logger_name=__name__).get_log()
4242

43-
def load_config(
43+
def _load_config(
4444
self, config_path: Optional[str], params: Optional[Dict[str, Any]]
4545
) -> DictConfig:
4646
if config_path is not None and Path(config_path).exists():
@@ -52,7 +52,7 @@ def load_config(
5252
cfg = ParseParams.update_batch(cfg, params)
5353
return cfg
5454

55-
def initialize(self, cfg: DictConfig):
55+
def _initialize(self, cfg: DictConfig):
5656
self.text_score = cfg.Global.text_score
5757
self.min_height = cfg.Global.min_height
5858
self.width_height_ratio = cfg.Global.width_height_ratio
@@ -273,7 +273,11 @@ def get_final_res(
273273
scores=rec_res.scores,
274274
word_results=rec_res.word_results,
275275
elapse_list=[det_res.elapse, cls_res.elapse, rec_res.elapse],
276-
lang_type=self.cfg.Rec.lang_type,
276+
viser=VisRes(
277+
text_score=self.cfg.Global.text_score,
278+
lang_type=self.cfg.Rec.lang_type,
279+
font_path=self.cfg.Global.font_path,
280+
),
277281
)
278282
ocr_res = self.filter_by_text_score(ocr_res)
279283
if len(ocr_res) <= 0:
@@ -409,11 +413,12 @@ def main(arg_list: Optional[List[str]] = None):
409413
save_path = cur_dir / f"{Path(args.img_path).stem}_vis_single.png"
410414
cv2.imwrite(str(save_path), vis_img)
411415
print(f"The vis single result has saved in {save_path}")
412-
else:
413-
save_path = cur_dir / f"{Path(args.img_path).stem}_vis.png"
414-
vis_img = vis(args.img_path, result.boxes, result.txts, result.scores)
415-
cv2.imwrite(str(save_path), vis_img)
416-
print(f"The vis result has saved in {save_path}")
416+
return
417+
418+
save_path = cur_dir / f"{Path(args.img_path).stem}_vis.png"
419+
vis_img = vis(args.img_path, result.boxes, result.txts, result.scores)
420+
cv2.imwrite(str(save_path), vis_img)
421+
print(f"The vis result has saved in {save_path}")
417422

418423

419424
if __name__ == "__main__":

python/rapidocr/utils/output.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class RapidOCROutput:
2525
)
2626
elapse_list: List[Union[float, None]] = field(default_factory=list)
2727
elapse: float = field(init=False)
28-
lang_type: Optional[str] = None
28+
viser: Optional[VisRes] = None
2929

3030
def __post_init__(self):
3131
self.elapse = sum(v for v in self.elapse_list if isinstance(v, float))
@@ -41,21 +41,17 @@ def to_json(self):
4141
def to_markdown(self) -> str:
4242
return ToMarkdown.to(self.boxes, self.txts)
4343

44-
def vis(self, save_path: Optional[str] = None, font_path: Optional[str] = None):
44+
def vis(self, save_path: Optional[str] = None):
4545
if self.img is None or self.boxes is None:
4646
logger.warning("No image or boxes to visualize.")
4747
return
4848

49-
vis = VisRes()
49+
if self.viser is None:
50+
logger.error("vis instance is None")
51+
return
52+
5053
if all(v is None for v in self.word_results):
51-
vis_img = vis(
52-
self.img,
53-
self.boxes,
54-
self.txts,
55-
self.scores,
56-
font_path=font_path,
57-
lang_type=self.lang_type,
58-
)
54+
vis_img = self.viser(self.img, self.boxes, self.txts, self.scores)
5955

6056
if save_path is not None:
6157
save_img(save_path, vis_img)
@@ -65,14 +61,7 @@ def vis(self, save_path: Optional[str] = None, font_path: Optional[str] = None):
6561
# single word vis
6662
words_results = sum(self.word_results, ())
6763
words, words_scores, words_boxes = list(zip(*words_results))
68-
vis_img = vis(
69-
self.img,
70-
words_boxes,
71-
words,
72-
words_scores,
73-
font_path=font_path,
74-
lang_type=self.lang_type,
75-
)
64+
vis_img = self.viser(self.img, words_boxes, words, words_scores)
7665

7766
if save_path is not None:
7867
save_img(save_path, vis_img)

python/rapidocr/utils/vis_res.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,32 @@
2525

2626

2727
class VisRes:
28-
def __init__(self, text_score: float = 0.5):
28+
def __init__(
29+
self,
30+
text_score: float = 0.5,
31+
lang_type: Optional[LangRec] = None,
32+
font_path: Optional[str] = None,
33+
):
2934
self.logger = Logger(logger_name=__name__).get_log()
3035

3136
self.text_score = text_score
3237
self.load_img = LoadImage()
3338

3439
self.font_cfg = OmegaConf.load(FONT_YAML_PATH).fonts
3540

41+
self.font_path = self.get_font_path(font_path, lang_type)
42+
self.logger.info(f"Using {self.font_path} to visualize results.")
43+
3644
def __call__(
3745
self,
3846
img_content: InputType,
3947
dt_boxes: np.ndarray,
4048
txts: Optional[Union[List[str], Tuple[str]]] = None,
4149
scores: Optional[Tuple[float]] = None,
42-
font_path: Optional[str] = None,
43-
lang_type: Optional[LangRec] = None,
4450
) -> np.ndarray:
4551
if txts is None:
4652
return self.draw_dt_boxes(img_content, dt_boxes, scores)
47-
48-
font_path = self.get_font_path(font_path, lang_type)
49-
return self.draw_ocr_box_txt(img_content, dt_boxes, txts, font_path, scores)
53+
return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores)
5054

5155
def draw_dt_boxes(
5256
self,
@@ -180,7 +184,6 @@ def draw_ocr_box_txt(
180184
img_content: InputType,
181185
dt_boxes: np.ndarray,
182186
txts: Union[List[str], Tuple[str]],
183-
font_path: str,
184187
scores: Optional[Tuple[float]] = None,
185188
) -> np.ndarray:
186189
image = Image.fromarray(self.load_img(img_content))
@@ -208,7 +211,7 @@ def draw_ocr_box_txt(
208211
box_width = self.get_box_width(box)
209212
if box_height > 2 * box_width:
210213
font_size = max(int(box_width * 0.9), 10)
211-
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
214+
font = ImageFont.truetype(self.font_path, font_size, encoding="utf-8")
212215
cur_y = box[0][1]
213216

214217
for c in txt:
@@ -218,7 +221,7 @@ def draw_ocr_box_txt(
218221
cur_y += self.get_char_size(font, c)
219222
else:
220223
font_size = max(int(box_height * 0.8), 10)
221-
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
224+
font = ImageFont.truetype(self.font_path, font_size, encoding="utf-8")
222225
draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
223226

224227
img_left = Image.blend(image, img_left, 0.5)

0 commit comments

Comments
 (0)