Skip to content

Commit 30c4b0c

Browse files
committed
feat: extract param for wiredV2
1 parent d1d4db8 commit 30c4b0c

File tree

8 files changed

+104
-34
lines changed

8 files changed

+104
-34
lines changed

README.md

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
</div>
1414

1515
### 最近更新
16-
- **2024.10.13**
17-
- 补充最新paddlex-SLANet-plus 测评结果(已集成模型到[RapidTable](https://github.com/RapidAI/RapidTable)仓库)
1816
- **2024.10.22**
1917
- 补充复杂背景多表格检测提取方案[RapidTableDet](https://github.com/RapidAI/RapidTableDetection)
2018
- **2024.10.29**
2119
- 使用yolo11重新训练表格分类器,修正wired_table_rec v2逻辑坐标还原错误,并更新测评
20+
- **2024.11.12**
21+
- 抽离模型识别和处理过程核心阈值,方便大家进行微调适配自己的场景
2222

2323
### 简介
2424
💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。
@@ -68,6 +68,7 @@
6868
wired_table_rec_v2(有线表格精度最高): 通用场景有线表格(论文,杂志,期刊, 收据,单据,账单)
6969

7070
paddlex-SLANet-plus(综合精度最高): 文档场景表格(论文,杂志,期刊中的表格)
71+
[微调入参参考](#核心参数)
7172

7273
### 安装
7374

@@ -100,12 +101,6 @@ else:
100101
html, elasp, polygons, logic_points, ocr_res = table_engine(img_path)
101102
print(f"elasp: {elasp}")
102103

103-
#仅返回表格物理box和行列逻辑坐标,不进行ocr识别
104-
#html, elasp, polygons, logic_points, ocr_res = table_engine(img_path, need_ocr=False)
105-
106-
#默认没有匹配的表格框进行了ocr再识别,取消该行为
107-
#html, elasp, polygons, logic_points, ocr_res = table_engine(img_path, rec_again=False)
108-
109104
# 使用其他ocr模型
110105
#ocr_engine =RapidOCR(det_model_dir="xxx/det_server_infer.onnx",rec_model_dir="xxx/rec_server_infer.onnx")
111106
#ocr_res, _ = ocr_engine(img_path)
@@ -164,6 +159,27 @@ for i, res in enumerate(result):
164159
# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
165160
```
166161

162+
### 核心参数
163+
```python
164+
wired_table_rec = WiredTableRecognition()
165+
html, elasp, polygons, logic_points, ocr_res = wired_table_rec(
166+
img_path,
167+
version="v2", #默认使用v2线框模型,切换阿里读光模型可改为v1
168+
morph_close=True, # 是否进行形态学操作,辅助找到更多线框,默认为True
169+
more_h_lines=True, # 是否基于线框检测结果进行更多水平线检查,辅助找到更小线框, 默认为True
170+
more_v_lines=True, # 是否基于线框检测结果进行更多垂直线检查,辅助找到更小线框, 默认为True
171+
extend_line=True, # 是否基于线框检测结果进行线段延长,辅助找到更多线框, 默认为True
172+
need_ocr=True, # 是否进行OCR识别, 默认为True
173+
rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True
174+
)
175+
lineless_table_rec = LinelessTableRecognition()
176+
html, elasp, polygons, logic_points, ocr_res = lineless_table_rec(
177+
need_ocr=True, # 是否进行OCR识别, 默认为True
178+
rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True
179+
)
180+
```
181+
182+
167183
## FAQ (Frequently Asked Questions)
168184
1. **问:识别框丢失了内部文字信息**
169185
- 答:默认使用的rapidocr小模型,如果需要更高精度的效果,可以从 [模型列表](https://rapidai.github.io/RapidOCRDocs/model_list/#_1)

demo_wired.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,17 @@
1515

1616
table_rec = WiredTableRecognition()
1717

18-
img_path = "tests/test_files/wired/table1.png"
19-
html, elasp, polygons, logic_points, ocr_res = table_rec(img_path)
18+
img_path = "tests/test_files/wired/wired_big_box.png"
19+
html, elasp, polygons, logic_points, ocr_res = table_rec(
20+
img_path,
21+
version="v2", # 默认使用v2线框模型,切换阿里读光模型可改为v1
22+
morph_close=True, # 是否进行形态学操作,辅助找到更多线框,默认为True
23+
more_h_lines=True, # 是否基于线框检测结果进行更多水平线检查,辅助找到更小线框, 默认为True
24+
more_v_lines=True, # 是否基于线框检测结果进行更多垂直线检查,辅助找到更小线框, 默认为True
25+
extend_line=True, # 是否基于线框检测结果进行线段延长,辅助找到更多线框, 默认为True
26+
need_ocr=True, # 是否进行OCR识别, 默认为True
27+
rec_again=True, # 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True
28+
)
2029

2130
print(f"cost: {elasp:.5f}")
2231

@@ -29,6 +38,6 @@
2938
plot_rec_box_with_logic_info(
3039
img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons
3140
)
32-
plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res)
41+
plot_rec_box(f"{output_dir}/table_rec_box.jpg", f"{output_dir}/ocr_box.jpg", ocr_res)
3342

3443
print(f"The results has been saved under {output_dir}")
175 KB
Loading

tests/test_wired_table_rec.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,22 @@ def test_input_normal(img_path, gt_td_nums, gt2):
6565
assert td_nums >= gt_td_nums
6666

6767

68+
@pytest.mark.parametrize(
69+
"img_path, gt_td_nums",
70+
[
71+
("wired_big_box.png", 70),
72+
],
73+
)
74+
def test_input_normal(img_path, gt_td_nums):
75+
img_path = test_file_dir / img_path
76+
77+
ocr_result, _ = ocr_engine(img_path)
78+
table_str, *_ = table_recog(str(img_path), ocr_result)
79+
td_nums = get_td_nums(table_str)
80+
81+
assert td_nums >= gt_td_nums
82+
83+
6884
@pytest.mark.parametrize(
6985
"box1, box2, threshold, expected",
7086
[

wired_table_rec/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __call__(
6464
rec_again = kwargs.get("rec_again", True)
6565
need_ocr = kwargs.get("need_ocr", True)
6666
img = self.load_img(img)
67-
polygons = self.table_line_rec(img)
67+
polygons = self.table_line_rec(img, **kwargs)
6868
if polygons is None:
6969
logging.warning("polygons is None.")
7070
return "", 0.0, None, None, None

wired_table_rec/table_line_rec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, model_path: Optional[str] = None):
3636

3737
self.session = OrtInferSession(model_path)
3838

39-
def __call__(self, img: np.ndarray) -> Optional[np.ndarray]:
39+
def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
4040
img_info = self.preprocess(img)
4141
pred = self.infer(img_info)
4242
polygons = self.postprocess(pred)

wired_table_rec/table_line_rec_plus.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,12 @@ def __init__(self, model_path: Optional[str] = None):
3131

3232
self.session = OrtInferSession(model_path)
3333

34-
def __call__(self, img: np.ndarray) -> Optional[np.ndarray]:
34+
def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
3535
img_info = self.preprocess(img)
3636
pred = self.infer(img_info)
37-
polygons = self.postprocess(img, pred)
37+
polygons = self.postprocess(img, pred, **kwargs)
3838
if polygons.size == 0:
3939
return None
40-
4140
polygons = polygons.reshape(polygons.shape[0], 4, 2)
4241
polygons[:, 3, :], polygons[:, 1, :] = (
4342
polygons[:, 1, :].copy(),
@@ -68,7 +67,25 @@ def infer(self, input):
6867
result = result[0].astype(np.uint8)
6968
return result
7069

71-
def postprocess(self, img, pred, row=50, col=30, alph=15, angle=50):
70+
def postprocess(self, img, pred, **kwargs):
71+
row = kwargs.get("row", 50) if kwargs else 50
72+
col = kwargs.get("col", 30) if kwargs else 30
73+
h_lines_threshold = kwargs.get("h_lines_threshold", 100) if kwargs else 100
74+
v_lines_threshold = kwargs.get("v_lines_threshold", 15) if kwargs else 15
75+
angle = kwargs.get("angle", 50) if kwargs else 50
76+
morph_close = (
77+
kwargs.get("morph_close", True) if kwargs else True
78+
) # 是否进行闭合运算以找到更多小的框
79+
more_h_lines = (
80+
kwargs.get("more_h_lines", True) if kwargs else True
81+
) # 是否调整以找到更多的横线
82+
more_v_lines = (
83+
kwargs.get("more_v_lines", True) if kwargs else True
84+
) # 是否调整以找到更多的横线
85+
extend_line = (
86+
kwargs.get("extend_line", True) if kwargs else True
87+
) # 是否进行线段延长使得端点连接
88+
7289
ori_shape = img.shape
7390
pred = np.uint8(pred)
7491
hpred = copy.deepcopy(pred) # 横线
@@ -89,16 +106,19 @@ def postprocess(self, img, pred, row=50, col=30, alph=15, angle=50):
89106
vpred = cv2.morphologyEx(
90107
vpred, cv2.MORPH_CLOSE, vkernel, iterations=1
91108
) # 先膨胀后腐蚀的过程
92-
hpred = cv2.morphologyEx(hpred, cv2.MORPH_CLOSE, hkernel, iterations=1)
109+
if morph_close:
110+
hpred = cv2.morphologyEx(hpred, cv2.MORPH_CLOSE, hkernel, iterations=1)
93111
colboxes = get_table_line(vpred, axis=1, lineW=col) # 竖线
94112
rowboxes = get_table_line(hpred, axis=0, lineW=row) # 横线
95-
# rboxes_row_, rboxes_col_ = adjust_lines(rowboxes, colboxes, alph = alph, angle=angle)
96-
rboxes_row_ = adjust_lines(rowboxes, alph=100, angle=angle)
97-
rboxes_col_ = adjust_lines(colboxes, alph=alph, angle=angle)
113+
rboxes_row_, rboxes_col_ = [], []
114+
if more_h_lines:
115+
rboxes_row_ = adjust_lines(rowboxes, alph=h_lines_threshold, angle=angle)
116+
if more_v_lines:
117+
rboxes_col_ = adjust_lines(colboxes, alph=v_lines_threshold, angle=angle)
98118
rowboxes += rboxes_row_
99119
colboxes += rboxes_col_
100-
rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
101-
120+
if extend_line:
121+
rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
102122
tmp = np.zeros(img.shape[:2], dtype="uint8")
103123
tmp = draw_lines(tmp, rowboxes + colboxes, color=255, lineW=2)
104124
labels = measure.label(tmp < 255, connectivity=2) # 8连通区域标记

wired_table_rec/utils_table_recover.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,17 @@ def plot_rec_box_with_logic_info(img_path, output_path, logic_points, sorted_pol
267267

268268
cv2.putText(
269269
img,
270-
f"{idx}-{logic_points[idx]}",
271-
(x1, y1),
270+
f"row:{logic_points[idx][0]}-{logic_points[idx][1]}",
271+
(x0 + 1, y0 + 15),
272+
cv2.FONT_HERSHEY_PLAIN,
273+
font_scale,
274+
(0, 0, 255),
275+
thickness,
276+
)
277+
cv2.putText(
278+
img,
279+
f"col:{logic_points[idx][2]}-{logic_points[idx][3]}",
280+
(x0 + 1, y0 + 40),
272281
cv2.FONT_HERSHEY_PLAIN,
273282
font_scale,
274283
(0, 0, 255),
@@ -303,15 +312,15 @@ def plot_rec_box(img_path, output_path, sorted_polygons):
303312
font_scale = 1.0 # 原先是0.5
304313
thickness = 2 # 原先是1
305314

306-
cv2.putText(
307-
img,
308-
str(idx),
309-
(x1, y1),
310-
cv2.FONT_HERSHEY_PLAIN,
311-
font_scale,
312-
(0, 0, 255),
313-
thickness,
314-
)
315+
# cv2.putText(
316+
# img,
317+
# str(idx),
318+
# (x1, y1),
319+
# cv2.FONT_HERSHEY_PLAIN,
320+
# font_scale,
321+
# (0, 0, 255),
322+
# thickness,
323+
# )
315324
os.makedirs(os.path.dirname(output_path), exist_ok=True)
316325
# 保存绘制后的图像
317326
cv2.imwrite(output_path, img)

0 commit comments

Comments
 (0)