Skip to content

Commit b8de017

Browse files
authored
Merge pull request #116 from yuege969/main
Add support for engine torch gpu
2 parents 1d26ad5 + ff867a5 commit b8de017

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

rapid_table/main.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
LoadImage,
1616
Logger,
1717
ModelType,
18+
EngineType,
1819
RapidTableInput,
1920
RapidTableOutput,
2021
get_boxes_recs,
@@ -45,7 +46,21 @@ def __init__(self, cfg: Optional[RapidTableInput] = None):
4546

4647
def _init_ocr_engine(self):
4748
try:
48-
return import_package("rapidocr").RapidOCR()
49+
rapidocr_ = import_package("rapidocr")
50+
if rapidocr_ is None:
51+
raise ModuleNotFoundError("rapidocr package is not installed")
52+
if self.cfg.engine_type == EngineType.TORCH:
53+
EngineType_RapidOCR = rapidocr_.EngineType
54+
return rapidocr_.RapidOCR(
55+
params={
56+
"Det.engine_type": EngineType_RapidOCR.TORCH,
57+
"Cls.engine_type": EngineType_RapidOCR.TORCH,
58+
"Rec.engine_type": EngineType_RapidOCR.TORCH,
59+
"EngineConfig.torch.use_cuda": True, # Use torch GPU to infer
60+
"EngineConfig.torch.gpu_id": 0 # Specify GPU id
61+
}
62+
)
63+
return rapidocr_.RapidOCR()
4964
except ModuleNotFoundError:
5065
logger.warning("rapidocr package is not installed, only table rec")
5166
return None

0 commit comments

Comments
 (0)