88from bs4 import BeautifulSoup
99from rapidocr_onnxruntime import RapidOCR
1010
11- from wired_table_rec .utils import rescale_size
11+ from wired_table_rec .main import RapidTableInput , ModelType
12+ from wired_table_rec .utils .utils import rescale_size
1213from wired_table_rec .utils .utils_table_recover import (
1314 plot_html_table ,
1415 is_single_axis_contained ,
2526from wired_table_rec import WiredTableRecognition
2627
2728test_file_dir = cur_dir / "test_files" / "wired"
28-
29- table_recog = WiredTableRecognition ()
29+ input_args = RapidTableInput ( model_type = ModelType . UNET . value )
30+ table_recog = WiredTableRecognition (input_args )
3031ocr_engine = RapidOCR ()
3132
3233
@@ -40,9 +41,13 @@ def get_td_nums(html: str) -> int:
4041
4142def test_squeeze_bug ():
4243 img_path = test_file_dir / "squeeze_error.jpeg"
43- ocr_result , _ = ocr_engine (img_path )
44- table_str , * _ = table_recog (str (img_path ), ocr_result )
45- td_nums = get_td_nums (table_str )
44+ ocr_result , _ = ocr_engine (str (img_path ))
45+ table_results = table_recog (str (img_path ))
46+ table_html_str , table_cell_bboxes = (
47+ table_results .pred_html ,
48+ table_results .cell_bboxes ,
49+ )
50+ td_nums = get_td_nums (table_html_str )
4651 assert td_nums >= 160
4752
4853
@@ -58,9 +63,13 @@ def test_squeeze_bug():
5863def test_input_normal (img_path , gt_td_nums , gt2 ):
5964 img_path = test_file_dir / img_path
6065
61- ocr_result , _ = ocr_engine (img_path )
62- table_str , * _ = table_recog (str (img_path ), ocr_result )
63- td_nums = get_td_nums (table_str )
66+ ocr_result , _ = ocr_engine (str (img_path ))
67+ table_results = table_recog (str (img_path ))
68+ table_html_str , table_cell_bboxes = (
69+ table_results .pred_html ,
70+ table_results .cell_bboxes ,
71+ )
72+ td_nums = get_td_nums (table_html_str )
6473
6574 assert td_nums >= gt_td_nums
6675
@@ -74,9 +83,13 @@ def test_input_normal(img_path, gt_td_nums, gt2):
7483def test_enhance_box_line (img_path , gt_td_nums ):
7584 img_path = test_file_dir / img_path
7685
77- ocr_result , _ = ocr_engine (img_path )
78- table_str , * _ = table_recog (str (img_path ), ocr_result , enhance_box_line = False )
79- td_nums = get_td_nums (table_str )
86+ ocr_result , _ = ocr_engine (str (img_path ))
87+ table_results = table_recog (str (img_path ), enhance_box_line = False )
88+ table_html_str , table_cell_bboxes = (
89+ table_results .pred_html ,
90+ table_results .cell_bboxes ,
91+ )
92+ td_nums = get_td_nums (table_html_str )
8093
8194 assert td_nums <= gt_td_nums
8295
@@ -291,10 +304,13 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html):
291304def test_no_rec_again (img_path , gt_td_nums , gt2 ):
292305 img_path = test_file_dir / img_path
293306
294- ocr_result , _ = ocr_engine (img_path )
295- table_str , * _ = table_recog (str (img_path ), ocr_result , rec_again = False )
296- td_nums = get_td_nums (table_str )
297-
307+ ocr_result , _ = ocr_engine (str (img_path ))
308+ table_results = table_recog (str (img_path ), rec_again = False )
309+ table_html_str , table_cell_bboxes = (
310+ table_results .pred_html ,
311+ table_results .cell_bboxes ,
312+ )
313+ td_nums = get_td_nums (table_html_str )
298314 assert td_nums >= gt_td_nums
299315
300316
@@ -308,12 +324,15 @@ def test_no_rec_again(img_path, gt_td_nums, gt2):
308324def test_no_ocr (img_path , html_output , points_len ):
309325 img_path = test_file_dir / img_path
310326
311- ocr_result , _ = ocr_engine (img_path )
312- html , elasp , polygons , logic_points , ocr_res = table_recog (
313- str (img_path ), ocr_result , need_ocr = False
327+ ocr_result , _ = ocr_engine (str (img_path ))
328+ table_results = table_recog (str (img_path ), need_ocr = False )
329+ table_html_str , table_cell_bboxes , table_logic_points = (
330+ table_results .pred_html ,
331+ table_results .cell_bboxes ,
332+ table_results .logic_points ,
314333 )
315- assert len ( ocr_res ) == 0
316- assert len (polygons ) > points_len
317- assert len (logic_points ) > points_len
318- assert len (polygons ) == len (logic_points )
319- assert html == html_output
334+
335+ assert len (table_cell_bboxes ) > points_len
336+ assert len (table_logic_points ) > points_len
337+ assert len (table_cell_bboxes ) == len (table_logic_points )
338+ assert table_html_str == html_output
0 commit comments