Commit 38eab97 — Merge pull request #99 from RapidAI/develop

feat: adapt rapidocr v3 and refactor code

2 parents b9adda2 + ab0936a · 35 files changed · +1398 −1114 lines

README.md

Lines changed: 106 additions & 61 deletions
@@ -3,7 +3,7 @@
 <h1><b>📊 Rapid Table</b></h1>
 </div>

-<a href="https://huggingface.co/spaces/Joker1212/TableDetAndRec" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97-Online Demo-blue"></a>
+<a href="https://huggingface.co/spaces/RapidAI/TableStructureRec" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97-Online Demo-blue"></a>
 <a href="https://www.modelscope.cn/studios/RapidAI/TableRec/summary" target="_blank"><img src="https://img.shields.io/badge/魔搭-Demo-blue"></a>
 <a href=""><img src="https://img.shields.io/badge/Python->=3.6-aff.svg"></a>
 <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
@@ -35,6 +35,57 @@ unitable is a transformer model from UniTable; it has the highest accuracy and currently supports PyTorch only
 <img src="https://github.com/RapidAI/RapidTable/releases/download/assets/preview.gif" alt="Demo" width="80%" height="80%">
 </div>

+### 🖥️ Supported devices
+
+Supported via the ONNXRuntime inference engine:
+
+- DirectML
+- Ascend NPU
+
+How to use:
+
+1. Install (other onnxruntime packages must be uninstalled first):
+
+```bash
+# DirectML
+pip install onnxruntime-directml
+
+# Ascend NPU
+pip install onnxruntime-cann
+```
+
+2. Use:
+
+```python
+from rapidocr import RapidOCR
+
+from rapid_table import ModelType, RapidTable, RapidTableInput
+
+# DirectML
+ocr_engine = RapidOCR(params={"EngineConfig.onnxruntime.use_dml": True})
+input_args = RapidTableInput(
+    model_type=ModelType.SLANETPLUS, engine_cfg={"use_dml": True}
+)
+
+# Ascend NPU
+ocr_engine = RapidOCR(params={"EngineConfig.onnxruntime.use_cann": True})
+
+input_args = RapidTableInput(
+    model_type=ModelType.SLANETPLUS,
+    engine_cfg={"use_cann": True, "cann_ep_cfg.gpu_id": 1},
+)
+
+table_engine = RapidTable(input_args)
+
+img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
+rapid_ocr_output = ocr_engine(img_path)
+ocr_result = list(
+    zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+)
+results = table_engine(img_path, ocr_result)
+results.vis(save_dir="outputs", save_name="vis")
+```
+
 ### 🧩 Model list

 | `model_type` | Model name | Inference framework | Model size | Inference time (single 60 KB image) |
@@ -59,11 +110,13 @@ unitable is a transformer model from UniTable; it has the highest accuracy and currently supports PyTorch only
 |:---:|:---|
 |v0.x|`rapidocr_onnxruntime`|
 |v1.0.x|`rapidocr>=2.0.0,<3.0.0`|
-|v1.x.0|`rapidocr>=3.0.0`|
+|v2.x|`rapidocr>=3.0.0`|

 Because the models are small, the slanet-plus table recognition model (`slanet-plus.onnx`) is pre-bundled into the whl package. When the `RapidTable` class is initialized, the other models are downloaded automatically, according to `model_type`, into the `models` directory of the installed package. You can also point to your own model via `RapidTableInput(model_path='')`; note that this only applies to the `model_type`s we currently support.

-> ⚠️ Note: as of `rapid_table>=v0.1.0`, the `rapidocr` dependency is no longer force-bundled into `rapid_table`. Install the `rapidocr_onnxruntime` package yourself before use.
+> ⚠️ Note: as of `rapid_table>=v1.0.0`, the `rapidocr` dependency is no longer force-bundled into `rapid_table`. Install the `rapidocr` package yourself before use.
+>
+> ⚠️ Note: for `rapid_table>=v0.1.0,<1.0.0`, the `rapidocr` dependency is no longer force-bundled into `rapid_table`. Install the `rapidocr_onnxruntime` package yourself before use.

 ```bash
 pip install rapidocr
@@ -83,90 +136,82 @@ pip install onnxruntime-gpu # for onnx gpu inference

 > ⚠️ Note: as of `rapid_table>=1.0.0`, model inputs are wrapped in dataclasses to simplify parameter passing and keep it compatible. Input and output are defined as follows:

-```python
-# Input
-@dataclass
-class RapidTableInput:
-    model_type: Optional[str] = ModelType.SLANETPLUS.value
-    model_path: Union[str, Path, None, Dict[str, str]] = None
-    use_cuda: bool = False
-    device: str = "cpu"
-
-# Output
-@dataclass
-class RapidTableOutput:
-    pred_html: Optional[str] = None
-    cell_bboxes: Optional[np.ndarray] = None
-    logic_points: Optional[np.ndarray] = None
-    elapse: Optional[float] = None
-
-# Usage example
-input_args = RapidTableInput(model_type="unitable")
-table_engine = RapidTable(input_args)
-
-img_path = 'test_images/table.jpg'
-table_results = table_engine(img_path)
+ModelType supports the existing 4 models ([source](./rapid_table/utils/typings.py)):

-print(table_results.pred_html)
+```python
+class ModelType(Enum):
+    PPSTRUCTURE_EN = "ppstructure_en"
+    PPSTRUCTURE_ZH = "ppstructure_zh"
+    SLANETPLUS = "slanet_plus"
+    UNITABLE = "unitable"
 ```

-Full example:
+##### CPU usage

 ```python
-from pathlib import Path
-
-from rapidocr import RapidOCR, VisRes
-from rapid_table import RapidTable, RapidTableInput, VisTable

-# Enable onnx-gpu inference
-# input_args = RapidTableInput(use_cuda=True)
-# table_engine = RapidTable(input_args)
+from rapidocr import RapidOCR

-# Use the torch inference version of the unitable model
-# input_args = RapidTableInput(model_type="unitable", use_cuda=True, device="cuda:0")
-# table_engine = RapidTable(input_args)
+from rapid_table import ModelType, RapidTable, RapidTableInput

 ocr_engine = RapidOCR()
-vis_ocr = VisRes()

-# slanet_plus is the default model
-input_args = RapidTableInput(model_type="unitable")
+input_args = RapidTableInput(model_type=ModelType.UNITABLE)
 table_engine = RapidTable(input_args)
-viser = VisTable()

-img_path = "tests/test_files/table.jpg"
+img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"

-# OCR
-rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
-ocr_result = list(
-    zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
-)
 # Use per-character recognition
+# rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
 # word_results = rapid_ocr_output.word_results
 # ocr_result = [
-#     [word_result[2], word_result[0], word_result[1]] for word_result in word_results
+#     [word_result[0][2], word_result[0][0], word_result[0][1]]
+#     for word_result in word_results
 # ]

-table_results = table_engine(img_path, ocr_result)
-table_html_str, table_cell_bboxes = table_results.pred_html, table_results.cell_bboxes
-# Save
-save_dir = Path("outputs")
-save_dir.mkdir(parents=True, exist_ok=True)
+rapid_ocr_output = ocr_engine(img_path)
+ocr_result = list(
+    zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+)
+results = table_engine(img_path, ocr_result)
+results.vis(save_dir="outputs", save_name="vis")
+```
+
+##### GPU usage
+
+```python
+from rapidocr import RapidOCR

-save_html_path = save_dir / f"{Path(img_path).stem}.html"
-save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
-save_logic_points_path = save_dir / f"{Path(img_path).stem}_table_col_row_vis{Path(img_path).suffix}"
+from rapid_table import ModelType, RapidTable, RapidTableInput

-# Visualize table rec result
-vis_imged = viser(img_path, table_results, save_html_path, save_drawed_path, save_logic_points_path)
+ocr_engine = RapidOCR()
+
+# onnxruntime-gpu
+input_args = RapidTableInput(
+    model_type=ModelType.SLANETPLUS, engine_cfg={"use_cuda": True, "gpu_id": 1}
+)

-print(f"The results has been saved {save_dir}")
+# torch gpu
+# input_args = RapidTableInput(
+#     model_type=ModelType.UNITABLE,
+#     engine_cfg={"use_cuda": True, "cuda_ep_cfg.gpu_id": 1},
+# )
+table_engine = RapidTable(input_args)
+
+img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
+rapid_ocr_output = ocr_engine(img_path)
+ocr_result = list(
+    zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+)
+results = table_engine(img_path, ocr_result)
+results.vis(save_dir="outputs", save_name="vis")
 ```

 #### 📦 Run from the terminal

 ```bash
-rapid_table -v -img test_images/table.jpg
+rapid_table test_images/table.jpg -v
 ```

 ### 📝 Results
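The README diff above replaces string `model_type` values with the new `ModelType` enum. As a standalone sanity check, the enum body from the diff can be exercised with only the standard library (the class body is copied from the diff; the lookups below are illustrative):

```python
from enum import Enum


class ModelType(Enum):
    # Values copied from the diff's rapid_table/utils/typings.py
    PPSTRUCTURE_EN = "ppstructure_en"
    PPSTRUCTURE_ZH = "ppstructure_zh"
    SLANETPLUS = "slanet_plus"
    UNITABLE = "unitable"


# Look up a member from its string value, e.g. when parsing a CLI argument
model = ModelType("slanet_plus")
print(model is ModelType.SLANETPLUS)  # True
print(model.name, model.value)        # SLANETPLUS slanet_plus
```

This is why `RapidTableInput(model_type=ModelType.UNITABLE)` and the older string form can coexist: `ModelType("unitable")` converts one into the other.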

demo.py

Lines changed: 19 additions & 51 deletions
@@ -1,60 +1,28 @@
 # -*- encoding: utf-8 -*-
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
-from pathlib import Path
+from rapidocr import RapidOCR

-from rapidocr import RapidOCR, VisRes
+from rapid_table import ModelType, RapidTable, RapidTableInput

-from rapid_table import RapidTable, RapidTableInput, VisTable
+ocr_engine = RapidOCR()

-if __name__ == "__main__":
-    # Init
-    ocr_engine = RapidOCR()
-    vis_ocr = VisRes()
+input_args = RapidTableInput(model_type=ModelType.UNITABLE)
+table_engine = RapidTable(input_args)

-    input_args = RapidTableInput(model_type="unitable")
-    table_engine = RapidTable(input_args)
-    viser = VisTable()
+img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"

-    img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
+# Use per-character recognition
+# rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+# word_results = rapid_ocr_output.word_results
+# ocr_result = [
+#     [word_result[0][2], word_result[0][0], word_result[0][1]]
+#     for word_result in word_results
+# ]

-    # OCR
-    rapid_ocr_output = ocr_engine(img_path)
-    ocr_result = list(
-        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
-    )
-    table_results = table_engine(img_path, ocr_result)
-
-    # Use per-character recognition
-    # word_results = rapid_ocr_output.word_results
-    # ocr_result = [
-    #     [word_result[2], word_result[0], word_result[1]] for word_result in word_results
-    # ]
-    # table_results = table_engine(img_path, ocr_result)
-
-    table_html_str, table_cell_bboxes = (
-        table_results.pred_html,
-        table_results.cell_bboxes,
-    )
-    # Save
-    save_dir = Path("outputs")
-    save_dir.mkdir(parents=True, exist_ok=True)
-
-    save_html_path = save_dir / f"{Path(img_path).stem}.html"
-    save_drawed_path = (
-        save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
-    )
-    save_logic_points_path = (
-        save_dir / f"{Path(img_path).stem}_table_col_row_vis{Path(img_path).suffix}"
-    )
-
-    # Visualize table rec result
-    vis_imged = viser(
-        img_path,
-        table_results,
-        save_html_path,
-        save_drawed_path,
-        save_logic_points_path,
-    )
-
-    print(f"The results has been saved {save_dir}")
+rapid_ocr_output = ocr_engine(img_path)
+ocr_result = list(
+    zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+)
+results = table_engine(img_path, ocr_result)
+results.vis(save_dir="outputs", save_name="vis")
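The refactored demo feeds OCR output to the table engine as `(box, text, score)` triples built with `zip`. A minimal sketch with made-up stand-ins for `rapid_ocr_output.boxes` / `.txts` / `.scores` (the boxes, texts, and scores here are dummy data, not real OCR output) shows the shape the engine receives:

```python
# Dummy stand-ins for rapid_ocr_output.boxes / .txts / .scores
boxes = [
    [[0, 0], [10, 0], [10, 5], [0, 5]],    # quadrilateral for the first line
    [[0, 6], [10, 6], [10, 11], [0, 11]],  # quadrilateral for the second line
]
txts = ("Header", "Cell")
scores = (0.98, 0.95)

# Same reshaping the demo performs: one (box, text, score) triple per detected line
ocr_result = list(zip(boxes, txts, scores))
print(ocr_result[0][1])  # Header
```

The commented per-character branch builds the same triples by hand, reordering each `word_result` into `[box, text, score]` order.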

rapid_table/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
 from .main import RapidTable, RapidTableInput
-from .utils import VisTable
+from .utils import EngineType, ModelType, VisTable

rapid_table/default_models.yaml

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+ppstructure_en:
+  model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/en_ppstructure_mobile_v2_SLANet.onnx
+  SHA256: 2cae17d16a16f9df7229e21665fe3fbe06f3ca85b2024772ee3e3142e955aa60
+
+ppstructure_zh:
+  model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/ch_ppstructure_mobile_v2_SLANet.onnx
+  SHA256: ddfc6c97ee4db2a5e9de4de8b6a14508a39d42d228503219fdfebfac364885e3
+
+slanet_plus:
+  model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/slanet-plus.onnx
+  SHA256: d57a942af6a2f57d6a4a0372573c696a2379bf5857c45e2ac69993f3b334514b
+
+unitable:
+  model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/unitable
+  SHA256:
+    encoder.pth: 2c66b3c6a3d1c86a00985bab2cd79412fc2b668ff39d338bc3c63d383b08684d
+    decoder.pth: fa342ef3de259576a01a5545ede804208ef35a124935e30df4768e6708dcb6cb
+    vocab.json: 05037d02c48d106639bc90284aa847e5e2151d4746b3f5efe1628599efbd668a
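default_models.yaml pairs each download URL with a SHA256 digest, presumably so the downloader can verify files after fetching. A hedged sketch of such a check using only the standard library (the helper name is illustrative, not rapid_table's actual API):

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream a file through hashlib in chunks to avoid loading it whole."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# The expected digest would come from default_models.yaml, e.g.:
# if sha256_of(Path("slanet-plus.onnx")) != expected_sha256:
#     raise ValueError("model download corrupted, re-download")
```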

rapid_table/engine_cfg.yaml

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+onnxruntime:
+  intra_op_num_threads: -1
+  inter_op_num_threads: -1
+  enable_cpu_mem_arena: false
+
+  cpu_ep_cfg:
+    arena_extend_strategy: "kSameAsRequested"
+
+  use_cuda: false
+  cuda_ep_cfg:
+    gpu_id: 0
+    arena_extend_strategy: "kNextPowerOfTwo"
+    cudnn_conv_algo_search: "EXHAUSTIVE"
+    do_copy_in_default_stream: true
+
+  use_dml: false
+  dm_ep_cfg: null
+
+  use_cann: false
+  cann_ep_cfg:
+    gpu_id: 0
+    arena_extend_strategy: "kNextPowerOfTwo"
+    npu_mem_limit: 21474836480 # 20 * 1024 * 1024 * 1024
+    op_select_impl_mode: "high_performance"
+    optypelist_for_implmode: "Gelu"
+    enable_cann_graph: true
+
+openvino:
+  inference_num_threads: -1
+
+paddle:
+  cpu_math_library_num_threads: -1
+  use_cuda: false
+  gpu_id: 0
+  gpu_mem: 500
+
+torch:
+  use_cuda: false
+  gpu_id: 0
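Keys like `"cann_ep_cfg.gpu_id"` in the README's `engine_cfg` examples suggest dotted paths that override nested values in this YAML. How rapid_table actually resolves them is not shown in this diff; the following is a hedged sketch of one plausible resolution over a plain nested dict (the function name and logic are illustrative assumptions):

```python
from typing import Any, Dict


def apply_overrides(cfg: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Apply {"a.b": value} style overrides to a nested dict (illustrative only)."""
    for dotted_key, value in overrides.items():
        node = cfg
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            # Walk (or create) intermediate mappings along the dotted path
            node = node.setdefault(part, {})
        node[leaf] = value
    return cfg


# Mirrors a subset of the onnxruntime section above
cfg = {"use_cann": False, "cann_ep_cfg": {"gpu_id": 0}}
apply_overrides(cfg, {"use_cann": True, "cann_ep_cfg.gpu_id": 1})
print(cfg)  # {'use_cann': True, 'cann_ep_cfg': {'gpu_id': 1}}
```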
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
