Skip to content

Commit 7a871ef

Browse files
committed
fix: fixed issue #114
1 parent 51bb0ec commit 7a871ef

File tree

3 files changed

+20 / -25 lines changed

3 files changed

+20 / -25 lines changed

demo.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,17 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: liekkaskono@163.com
4-
from rapidocr import RapidOCR
4+
from rapidocr import EngineType, RapidOCR
55

66
from rapid_table import ModelType, RapidTable, RapidTableInput
77

8-
ocr_engine = RapidOCR()
8+
ocr_engine = RapidOCR(
9+
params={
10+
"Det.engine_type": EngineType.TORCH,
11+
"Cls.engine_type": EngineType.TORCH,
12+
"Rec.engine_type": EngineType.TORCH,
13+
}
14+
)
915

1016
input_args = RapidTableInput(model_type=ModelType.UNITABLE)
1117
table_engine = RapidTable(input_args)

rapid_table/main.py

Lines changed: 9 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import time
66
from dataclasses import asdict
77
from pathlib import Path
8-
from typing import List, Optional, Tuple, Union
8+
from typing import Any, Dict, List, Optional, Tuple, Union
99

1010
import numpy as np
1111

@@ -15,7 +15,6 @@
1515
LoadImage,
1616
Logger,
1717
ModelType,
18-
EngineType,
1918
RapidTableInput,
2019
RapidTableOutput,
2120
get_boxes_recs,
@@ -39,32 +38,21 @@ def __init__(self, cfg: Optional[RapidTableInput] = None):
3938

4039
self.ocr_engine = None
4140
if cfg.use_ocr:
42-
self.ocr_engine = self._init_ocr_engine()
41+
self.ocr_engine = self._init_ocr_engine(self.cfg.ocr_params)
4342

4443
self.table_matcher = TableMatch()
4544
self.load_img = LoadImage()
4645

47-
def _init_ocr_engine(self):
48-
try:
49-
rapidocr_ = import_package("rapidocr")
50-
if rapidocr_ is None:
51-
raise ModuleNotFoundError("rapidocr package is not installed")
52-
if self.cfg.engine_type == EngineType.TORCH:
53-
EngineType_RapidOCR = rapidocr_.EngineType
54-
return rapidocr_.RapidOCR(
55-
params={
56-
"Det.engine_type": EngineType_RapidOCR.TORCH,
57-
"Cls.engine_type": EngineType_RapidOCR.TORCH,
58-
"Rec.engine_type": EngineType_RapidOCR.TORCH,
59-
"EngineConfig.torch.use_cuda": True, # Use torch GPU to infer
60-
"EngineConfig.torch.gpu_id": 0 # Specify GPU id
61-
}
62-
)
63-
return rapidocr_.RapidOCR()
64-
except ModuleNotFoundError:
46+
def _init_ocr_engine(self, params: Dict[Any, Any]):
47+
rapidocr_ = import_package("rapidocr")
48+
if rapidocr_ is None:
6549
logger.warning("rapidocr package is not installed, only table rec")
6650
return None
6751

52+
if not params:
53+
return rapidocr_.RapidOCR()
54+
return rapidocr_.RapidOCR(params=params)
55+
6856
def _init_table_structer(self):
6957
if self.cfg.model_type == ModelType.UNITABLE:
7058
from .table_structure.unitable import UniTableStructure

rapid_table/utils/typings.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,11 +29,12 @@ class RapidTableInput:
2929
model_type: Optional[ModelType] = ModelType.SLANETPLUS
3030
model_dir_or_path: Union[str, Path, None, Dict[str, str]] = None
3131

32-
use_ocr: bool = True
33-
3432
engine_type: Optional[EngineType] = None
3533
engine_cfg: dict = field(default_factory=dict)
3634

35+
use_ocr: bool = True
36+
ocr_params: dict = field(default_factory=dict)
37+
3738

3839
@dataclass
3940
class RapidTableOutput:

0 commit comments

Comments
 (0)