Skip to content

Commit 7fd2549

Browse files
committed
test: fix test
1 parent 6cda006 commit 7fd2549

File tree

2 files changed

+71
-42
lines changed

2 files changed

+71
-42
lines changed

tests/test_lineless_table_rec.py

Lines changed: 28 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@
55
from pathlib import Path
66
import pytest
77

8+
from lineless_table_rec.main import RapidTableInput, ModelType
9+
810
cur_dir = Path(__file__).resolve().parent
911
root_dir = cur_dir.parent
1012

@@ -14,8 +16,8 @@
1416
from lineless_table_rec import LinelessTableRecognition
1517

1618
test_file_dir = cur_dir / "test_files"
17-
18-
table_recog = LinelessTableRecognition()
19+
input_args = RapidTableInput(model_type=ModelType.LORE.value)
20+
table_recog = LinelessTableRecognition(input_args)
1921

2022

2123
@pytest.mark.parametrize(
@@ -27,12 +29,15 @@
2729
)
2830
def test_input_normal(img_path, table_str_len, td_nums):
2931
img_path = test_file_dir / img_path
30-
img = cv2.imread(str(img_path))
3132

32-
table_str, *_ = table_recog(img)
33+
table_results = table_recog(str(img_path))
34+
table_html_str, table_cell_bboxes = (
35+
table_results.pred_html,
36+
table_results.cell_bboxes,
37+
)
3338

34-
assert len(table_str) >= table_str_len
35-
assert table_str.count("td") == td_nums
39+
assert len(table_html_str) >= table_str_len
40+
assert table_html_str.count("td") == td_nums
3641

3742

3843
@pytest.mark.parametrize(
@@ -254,12 +259,15 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html):
254259
)
255260
def test_no_rec_again(img_path, table_str_len, td_nums):
256261
img_path = test_file_dir / img_path
257-
img = cv2.imread(str(img_path))
258262

259-
table_str, *_ = table_recog(img, rec_again=False)
263+
table_results = table_recog(str(img_path), rec_again=False)
264+
table_html_str, table_cell_bboxes = (
265+
table_results.pred_html,
266+
table_results.cell_bboxes,
267+
)
260268

261-
assert len(table_str) >= table_str_len
262-
assert table_str.count("td") == td_nums
269+
assert len(table_html_str) >= table_str_len
270+
assert table_html_str.count("td") == td_nums
263271

264272

265273
@pytest.mark.parametrize(
@@ -271,12 +279,14 @@ def test_no_rec_again(img_path, table_str_len, td_nums):
271279
)
272280
def test_no_ocr(img_path, html_output, points_len):
273281
img_path = test_file_dir / img_path
274-
275-
html, elasp, polygons, logic_points, ocr_res = table_recog(
276-
str(img_path), need_ocr=False
282+
table_results = table_recog(str(img_path), need_ocr=False)
283+
table_html_str, table_cell_bboxes, table_logic_points = (
284+
table_results.pred_html,
285+
table_results.cell_bboxes,
286+
table_results.logic_points,
277287
)
278-
assert len(ocr_res) == 0
279-
assert len(polygons) > points_len
280-
assert len(logic_points) > points_len
281-
assert len(polygons) == len(logic_points)
282-
assert html == html_output
288+
289+
assert len(table_cell_bboxes) > points_len
290+
assert len(table_logic_points) > points_len
291+
assert len(table_cell_bboxes) == len(table_logic_points)
292+
assert table_html_str == html_output

tests/test_wired_table_rec.py

Lines changed: 43 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,8 @@
88
from bs4 import BeautifulSoup
99
from rapidocr_onnxruntime import RapidOCR
1010

11-
from wired_table_rec.utils import rescale_size
11+
from wired_table_rec.main import RapidTableInput, ModelType
12+
from wired_table_rec.utils.utils import rescale_size
1213
from wired_table_rec.utils.utils_table_recover import (
1314
plot_html_table,
1415
is_single_axis_contained,
@@ -25,8 +26,8 @@
2526
from wired_table_rec import WiredTableRecognition
2627

2728
test_file_dir = cur_dir / "test_files" / "wired"
28-
29-
table_recog = WiredTableRecognition()
29+
input_args = RapidTableInput(model_type=ModelType.UNET.value)
30+
table_recog = WiredTableRecognition(input_args)
3031
ocr_engine = RapidOCR()
3132

3233

@@ -40,9 +41,13 @@ def get_td_nums(html: str) -> int:
4041

4142
def test_squeeze_bug():
4243
img_path = test_file_dir / "squeeze_error.jpeg"
43-
ocr_result, _ = ocr_engine(img_path)
44-
table_str, *_ = table_recog(str(img_path), ocr_result)
45-
td_nums = get_td_nums(table_str)
44+
ocr_result, _ = ocr_engine(str(img_path))
45+
table_results = table_recog(str(img_path))
46+
table_html_str, table_cell_bboxes = (
47+
table_results.pred_html,
48+
table_results.cell_bboxes,
49+
)
50+
td_nums = get_td_nums(table_html_str)
4651
assert td_nums >= 160
4752

4853

@@ -58,9 +63,13 @@ def test_squeeze_bug():
5863
def test_input_normal(img_path, gt_td_nums, gt2):
5964
img_path = test_file_dir / img_path
6065

61-
ocr_result, _ = ocr_engine(img_path)
62-
table_str, *_ = table_recog(str(img_path), ocr_result)
63-
td_nums = get_td_nums(table_str)
66+
ocr_result, _ = ocr_engine(str(img_path))
67+
table_results = table_recog(str(img_path))
68+
table_html_str, table_cell_bboxes = (
69+
table_results.pred_html,
70+
table_results.cell_bboxes,
71+
)
72+
td_nums = get_td_nums(table_html_str)
6473

6574
assert td_nums >= gt_td_nums
6675

@@ -74,9 +83,13 @@ def test_input_normal(img_path, gt_td_nums, gt2):
7483
def test_enhance_box_line(img_path, gt_td_nums):
7584
img_path = test_file_dir / img_path
7685

77-
ocr_result, _ = ocr_engine(img_path)
78-
table_str, *_ = table_recog(str(img_path), ocr_result, enhance_box_line=False)
79-
td_nums = get_td_nums(table_str)
86+
ocr_result, _ = ocr_engine(str(img_path))
87+
table_results = table_recog(str(img_path), enhance_box_line=False)
88+
table_html_str, table_cell_bboxes = (
89+
table_results.pred_html,
90+
table_results.cell_bboxes,
91+
)
92+
td_nums = get_td_nums(table_html_str)
8093

8194
assert td_nums <= gt_td_nums
8295

@@ -291,10 +304,13 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html):
291304
def test_no_rec_again(img_path, gt_td_nums, gt2):
292305
img_path = test_file_dir / img_path
293306

294-
ocr_result, _ = ocr_engine(img_path)
295-
table_str, *_ = table_recog(str(img_path), ocr_result, rec_again=False)
296-
td_nums = get_td_nums(table_str)
297-
307+
ocr_result, _ = ocr_engine(str(img_path))
308+
table_results = table_recog(str(img_path), rec_again=False)
309+
table_html_str, table_cell_bboxes = (
310+
table_results.pred_html,
311+
table_results.cell_bboxes,
312+
)
313+
td_nums = get_td_nums(table_html_str)
298314
assert td_nums >= gt_td_nums
299315

300316

@@ -308,12 +324,15 @@ def test_no_rec_again(img_path, gt_td_nums, gt2):
308324
def test_no_ocr(img_path, html_output, points_len):
309325
img_path = test_file_dir / img_path
310326

311-
ocr_result, _ = ocr_engine(img_path)
312-
html, elasp, polygons, logic_points, ocr_res = table_recog(
313-
str(img_path), ocr_result, need_ocr=False
327+
ocr_result, _ = ocr_engine(str(img_path))
328+
table_results = table_recog(str(img_path), need_ocr=False)
329+
table_html_str, table_cell_bboxes, table_logic_points = (
330+
table_results.pred_html,
331+
table_results.cell_bboxes,
332+
table_results.logic_points,
314333
)
315-
assert len(ocr_res) == 0
316-
assert len(polygons) > points_len
317-
assert len(logic_points) > points_len
318-
assert len(polygons) == len(logic_points)
319-
assert html == html_output
334+
335+
assert len(table_cell_bboxes) > points_len
336+
assert len(table_logic_points) > points_len
337+
assert len(table_cell_bboxes) == len(table_logic_points)
338+
assert table_html_str == html_output

0 commit comments

Comments
 (0)