Skip to content

Commit fc4c1b9

Browse files
committed
chore: update files
1 parent c1f6670 commit fc4c1b9

File tree

11 files changed

+143
-104
lines changed

11 files changed

+143
-104
lines changed

README.md

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ unitable是来源unitable的transformer模型,精度最高,暂仅支持pytor
7878
7979
table_engine = RapidTable(input_args)
8080
81-
img_path = "<https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg>"
82-
rapid_ocr_output = ocr_engine(img_path)
83-
ocr_result = list(
84-
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
85-
)
86-
results = table_engine(img_path, ocr_result)
81+
img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
82+
83+
ori_ocr_res = ocr_engine(img_path)
84+
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
85+
86+
results = table_engine(img_path, ocr_results=ocr_results)
8787
results.vis(save_dir="outputs", save_name="vis")
8888
```
8989

@@ -162,19 +162,17 @@ table_engine = RapidTable(input_args)
162162
163163
img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
164164
165-
# 使用单字识别
166-
# rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
167-
# word_results = rapid_ocr_output.word_results
168-
# ocr_result = [
165+
# # 使用单字识别
166+
# ori_ocr_res = ocr_engine(img_path, return_word_box=True)
167+
# ocr_results = [
169168
# [word_result[0][2], word_result[0][0], word_result[0][1]]
170-
# for word_result in word_results
169+
# for word_result in ori_ocr_res.word_results
171170
# ]
171+
# ocr_results = list(zip(*ocr_results))
172172
173-
rapid_ocr_output = ocr_engine(img_path)
174-
ocr_result = list(
175-
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
176-
)
177-
results = table_engine(img_path, ocr_result)
173+
ori_ocr_res = ocr_engine(img_path)
174+
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
175+
results = table_engine(img_path, ocr_results=ocr_results)
178176
results.vis(save_dir="outputs", save_name="vis")
179177
```
180178

@@ -201,11 +199,11 @@ input_args = RapidTableInput(
201199
table_engine = RapidTable(input_args)
202200
203201
img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
204-
rapid_ocr_output = ocr_engine(img_path)
205-
ocr_result = list(
206-
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
207-
)
208-
results = table_engine(img_path, ocr_result)
202+
203+
ori_ocr_res = ocr_engine(img_path)
204+
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
205+
206+
results = table_engine(img_path, ocr_results=ocr_results)
209207
results.vis(save_dir="outputs", save_name="vis")
210208
```
211209

demo.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,18 @@
1010
input_args = RapidTableInput(model_type=ModelType.UNITABLE)
1111
table_engine = RapidTable(input_args)
1212

13-
img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
13+
img_path = "tests/test_files/table_without_txt.jpg"
14+
# img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
1415

15-
# 使用单字识别
16-
# rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
17-
# word_results = rapid_ocr_output.word_results
18-
# ocr_result = [
16+
# # 使用单字识别
17+
# ori_ocr_res = ocr_engine(img_path, return_word_box=True)
18+
# ocr_results = [
1919
# [word_result[0][2], word_result[0][0], word_result[0][1]]
20-
# for word_result in word_results
20+
# for word_result in ori_ocr_res.word_results
2121
# ]
22+
# ocr_results = list(zip(*ocr_results))
2223

23-
rapid_ocr_output = ocr_engine(img_path)
24-
ocr_result = list(
25-
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
26-
)
27-
results = table_engine(img_path, ocr_result)
24+
ori_ocr_res = ocr_engine(img_path)
25+
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
26+
results = table_engine(img_path)
2827
results.vis(save_dir="outputs", save_name="vis")

rapid_table/main.py

Lines changed: 40 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ModelType,
1818
RapidTableInput,
1919
RapidTableOutput,
20+
get_boxes_recs,
2021
import_package,
2122
)
2223

@@ -34,7 +35,11 @@ def __init__(self, cfg: Optional[RapidTableInput] = None):
3435

3536
self.cfg = cfg
3637
self.table_structure = self._init_table_structer()
37-
self.ocr_engine = self._init_ocr_engine()
38+
39+
self.ocr_engine = None
40+
if cfg.use_ocr:
41+
self.ocr_engine = self._init_ocr_engine()
42+
3843
self.table_matcher = TableMatch()
3944
self.load_img = LoadImage()
4045

@@ -58,72 +63,48 @@ def _init_table_structer(self):
5863
def __call__(
5964
self,
6065
img_content: Union[str, np.ndarray, bytes, Path],
61-
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
66+
ocr_results: Optional[Tuple[np.ndarray, Tuple[str], Tuple[float]]] = None,
6267
) -> RapidTableOutput:
63-
if self.ocr_engine is None and ocr_result is None:
64-
raise ValueError(
65-
"One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
66-
)
68+
s = time.perf_counter()
6769

6870
img = self.load_img(img_content)
6971

70-
s = time.perf_counter()
71-
h, w = img.shape[:2]
72-
73-
if ocr_result is None:
74-
ocr_result = self.ocr_engine(img)
75-
ocr_result = list(
76-
zip(
77-
ocr_result.boxes,
78-
ocr_result.txts,
79-
ocr_result.scores,
80-
)
81-
)
82-
dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
72+
dt_boxes, rec_res = self.get_ocr_results(img, ocr_results)
73+
pred_structures, cell_bboxes, logic_points = self.get_table_rec_results(img)
74+
pred_html = self.get_table_matcher(
75+
pred_structures, cell_bboxes, dt_boxes, rec_res
76+
)
8377

84-
pred_structures, cell_bboxes, _ = self.table_structure(img)
78+
elapse = time.perf_counter() - s
79+
return RapidTableOutput(img, pred_html, cell_bboxes, logic_points, elapse)
8580

86-
# 适配slanet-plus模型输出的box缩放还原
87-
if self.cfg.model_type == ModelType.SLANETPLUS:
88-
cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
81+
def get_ocr_results(
82+
self, img: np.ndarray, ocr_results: Tuple[np.ndarray, Tuple[str], Tuple[float]]
83+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
84+
if ocr_results is not None:
85+
return get_boxes_recs(ocr_results, img.shape[:2])
8986

90-
pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
87+
if not self.cfg.use_ocr:
88+
return None, None
9189

92-
# 过滤掉占位的bbox
93-
mask = ~np.all(cell_bboxes == 0, axis=1)
94-
cell_bboxes = cell_bboxes[mask]
90+
ori_ocr_res = self.ocr_engine(img)
91+
if ori_ocr_res.boxes is None:
92+
logger.warning("OCR Result is empty")
93+
return None, None
9594

95+
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
96+
return get_boxes_recs(ocr_results, img.shape[:2])
97+
98+
def get_table_rec_results(self, img: np.ndarray):
99+
pred_structures, cell_bboxes, _ = self.table_structure(img)
96100
logic_points = self.table_matcher.decode_logic_points(pred_structures)
97-
elapse = time.perf_counter() - s
98-
return RapidTableOutput(img, pred_html, cell_bboxes, logic_points, elapse)
101+
return pred_structures, cell_bboxes, logic_points
99102

100-
def get_boxes_recs(
101-
self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
102-
) -> Tuple[np.ndarray, Tuple[str, str]]:
103-
dt_boxes, rec_res, scores = list(zip(*ocr_result))
104-
rec_res = list(zip(rec_res, scores))
105-
106-
r_boxes = []
107-
for box in dt_boxes:
108-
box = np.array(box)
109-
x_min = max(0, box[:, 0].min() - 1)
110-
x_max = min(w, box[:, 0].max() + 1)
111-
y_min = max(0, box[:, 1].min() - 1)
112-
y_max = min(h, box[:, 1].max() + 1)
113-
box = [x_min, y_min, x_max, y_max]
114-
r_boxes.append(box)
115-
dt_boxes = np.array(r_boxes)
116-
return dt_boxes, rec_res
117-
118-
def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
119-
h, w = img.shape[:2]
120-
resized = 488
121-
ratio = min(resized / h, resized / w)
122-
w_ratio = resized / (w * ratio)
123-
h_ratio = resized / (h * ratio)
124-
cell_bboxes[:, 0::2] *= w_ratio
125-
cell_bboxes[:, 1::2] *= h_ratio
126-
return cell_bboxes
103+
def get_table_matcher(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
    """Match OCR text to the predicted table cells and return the HTML.

    Args:
        pred_structures: Structure tokens produced by the table structure model.
        cell_bboxes: Cell boxes produced by the table structure model.
        dt_boxes: OCR text boxes, or ``None`` when no OCR result is available.
        rec_res: OCR ``(text, score)`` pairs, or ``None`` when no OCR result
            is available.

    Returns:
        Whatever ``self.table_matcher`` returns for the given inputs, or
        ``None`` when either OCR input is missing.
    """
    # The matcher needs both the boxes and the recognised text; if either one
    # is missing there is nothing to match, so skip it instead of handing it
    # a None argument.
    if dt_boxes is None or rec_res is None:
        return None

    return self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
127108

128109

129110
def parse_args(arg_list: Optional[List[str]] = None):
@@ -158,11 +139,9 @@ def main(arg_list: Optional[List[str]] = None):
158139
if table_engine.ocr_engine is None:
159140
raise ValueError("ocr engine is None")
160141

161-
rapid_ocr_output = table_engine.ocr_engine(img_path)
162-
ocr_result = list(
163-
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
164-
)
165-
table_results = table_engine(img_path, ocr_result)
142+
ori_ocr_res = table_engine.ocr_engine(img_path)
143+
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
144+
table_results = table_engine(img_path, ocr_results=ocr_results)
166145
print(table_results.pred_html)
167146

168147
if args.vis:

rapid_table/table_matcher/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(self, filter_ocr_result=True, use_master=False):
2525
def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
2626
if self.filter_ocr_result:
2727
dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
28+
2829
matched_index = self.match_result(dt_boxes, cell_bboxes)
2930
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
3031
return pred_html

rapid_table/table_structure/pp_structure/main.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import numpy as np
1818

19-
from rapid_table.utils.typings import EngineType
19+
from rapid_table.utils.typings import EngineType, ModelType
2020

2121
from ...inference_engine.base import get_engine
2222
from ..utils import get_struct_str
@@ -29,22 +29,46 @@ def __init__(self, cfg: Dict[str, Any]):
2929
if cfg["engine_type"] is None:
3030
cfg["engine_type"] = EngineType.ONNXRUNTIME
3131
self.session = get_engine(cfg["engine_type"])(cfg)
32+
self.cfg = cfg
3233

3334
self.preprocess_op = TablePreprocess()
3435

3536
self.character = self.session.get_character_list()
3637
self.postprocess_op = TableLabelDecode(self.character)
3738

38-
def __call__(self, img: np.ndarray) -> Tuple[List[str], np.ndarray, float]:
39+
def __call__(self, ori_img: np.ndarray) -> Tuple[List[str], np.ndarray, float]:
3940
s = time.perf_counter()
4041

41-
img, shape_list = self.preprocess_op(img)
42+
img, shape_list = self.preprocess_op(ori_img)
4243

4344
bbox_preds, struct_probs = self.session(img.copy())
4445

4546
post_result = self.postprocess_op(bbox_preds, struct_probs, [shape_list])
47+
4648
table_struct_str = get_struct_str(post_result["structure_batch_list"][0][0])
47-
bbox_list = post_result["bbox_batch_list"][0]
49+
cell_bboxes = post_result["bbox_batch_list"][0]
50+
51+
if self.cfg["model_type"] == ModelType.SLANETPLUS:
52+
cell_bboxes = self.rescale_cell_bboxes(ori_img, cell_bboxes)
53+
cell_bboxes = self.filter_blank_bbox(cell_bboxes)
4854

4955
elapse = time.perf_counter() - s
50-
return table_struct_str, bbox_list, elapse
56+
return table_struct_str, cell_bboxes, elapse
57+
58+
def rescale_cell_bboxes(
    self, img: np.ndarray, cell_bboxes: np.ndarray
) -> np.ndarray:
    """Rescale predicted cell boxes according to the image's aspect ratio.

    Args:
        img: Image array; only ``img.shape[:2]`` (height, width) is read.
        cell_bboxes: 2-D array of boxes whose even columns are x coordinates
            and odd columns are y coordinates.

    Returns:
        A new floating-point array with the x columns multiplied by
        ``w_ratio`` and the y columns by ``h_ratio``. The input array is
        left unmodified.
    """
    h, w = img.shape[:2]
    resized = 488  # NOTE(review): presumably the model's square input side - confirm
    ratio = min(resized / h, resized / w)
    # Algebraically these reduce to max(h, w) / w and max(h, w) / h; the
    # original arithmetic is kept so results stay bit-identical.
    w_ratio = resized / (w * ratio)
    h_ratio = resized / (h * ratio)

    bboxes = np.asarray(cell_bboxes)
    # Work on a float copy: in-place ``*=`` with a float factor would mutate
    # the caller's array and fails outright on integer arrays.
    scaled = bboxes.astype(np.result_type(bboxes.dtype, np.float32))
    if scaled.size == 0:
        return scaled

    scaled[:, 0::2] *= w_ratio
    scaled[:, 1::2] *= h_ratio
    return scaled
69+
70+
@staticmethod
71+
def filter_blank_bbox(cell_bboxes: np.ndarray) -> np.ndarray:
72+
# 过滤掉占位的bbox
73+
mask = ~np.all(cell_bboxes == 0, axis=1)
74+
return cell_bboxes[mask]

rapid_table/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
from .load_image import LoadImage
66
from .logger import Logger
77
from .typings import EngineType, ModelType, RapidTableInput, RapidTableOutput
8-
from .utils import import_package, is_url, mkdir, read_yaml
8+
from .utils import get_boxes_recs, import_package, is_url, mkdir, read_yaml
99
from .vis import VisTable

rapid_table/utils/typings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class RapidTableInput:
2929
model_type: Optional[ModelType] = ModelType.SLANETPLUS
3030
model_dir_or_path: Union[str, Path, None, Dict[str, str]] = None
3131

32+
use_ocr: bool = True
33+
3234
engine_type: Optional[EngineType] = None
3335
engine_cfg: dict = field(default_factory=dict)
3436

rapid_table/utils/utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,33 @@
44
import hashlib
55
import importlib
66
from pathlib import Path
7-
from typing import Union
7+
from typing import Tuple, Union
88
from urllib.parse import urlparse
99

1010
import cv2
1111
import numpy as np
1212
from omegaconf import DictConfig, OmegaConf
1313

1414

15+
def get_boxes_recs(
    ocr_results: Tuple[np.ndarray, Tuple[str], Tuple[float]],
    img_shape: Tuple[int, int],
) -> Tuple[np.ndarray, list]:
    """Convert raw OCR output into clipped axis-aligned boxes and text/score pairs.

    Args:
        ocr_results: Indexable triple ``(boxes, txts, scores)``. Each box is
            an array-like of ``(x, y)`` points.
        img_shape: ``(height, width)`` of the image; any further entries
            (e.g. a channel count) are ignored.

    Returns:
        ``(dt_boxes, rec_res)``. ``dt_boxes`` holds one
        ``[x_min, y_min, x_max, y_max]`` row per input box, padded by one
        pixel and clipped to the image bounds; it has shape ``(0, 4)`` when
        there are no boxes. ``rec_res`` is a list of ``(text, score)`` tuples.
    """
    rec_res = list(zip(ocr_results[1], ocr_results[2]))

    h, w = img_shape[:2]
    dt_boxes = []
    for box in ocr_results[0]:
        points = np.asarray(box)
        xs, ys = points[:, 0], points[:, 1]
        # Grow the bounding rectangle by one pixel on every side, but never
        # beyond the image.
        dt_boxes.append(
            [
                max(0, xs.min() - 1),
                max(0, ys.min() - 1),
                min(w, xs.max() + 1),
                min(h, ys.max() + 1),
            ]
        )

    if not dt_boxes:
        # Keep the result 2-D so column indexing downstream still works.
        return np.zeros((0, 4), dtype=np.float32), rec_res
    return np.array(dt_boxes), rec_res
32+
33+
1534
def save_img(save_path: Union[str, Path], img: np.ndarray):
1635
cv2.imwrite(str(save_path), img)
1736

rapid_table/utils/vis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __call__(
2424
save_drawed_path: Optional[str] = None,
2525
save_logic_path: Optional[str] = None,
2626
):
27-
if save_html_path:
27+
if pred_html and save_html_path:
2828
html_with_border = self.insert_border_style(pred_html)
2929
save_txt(save_html_path, html_with_border)
3030
self.logger.info(f"Save HTML to {save_html_path}")
34.1 KB
Loading

0 commit comments

Comments
 (0)