11import os
22
3+ import numpy as np
34import pytest
5+ import torch
6+ from PIL import Image
47from transformers .models .table_transformer .modeling_table_transformer import (
58 TableTransformerDecoder ,
69)
# True when the CI environment variable is unset, empty, or one of the listed "false"
# spellings; used below to skip tests that should only run in CI.
skip_outside_ci = os.environ.get("CI", "").lower() in {"", "false", "f", "0"}
1215
1316
@pytest.fixture()
def table_transformer():
    """Provide an `UnstructuredTableTransformerModel` initialized for structure recognition.

    The model is initialized with
    ``model="microsoft/table-transformer-structure-recognition"``.
    """
    model = tables.UnstructuredTableTransformerModel()
    model.initialize(model="microsoft/table-transformer-structure-recognition")
    return model
22+
23+
@pytest.fixture()
def example_image():
    """Open the sample table image from ``./sample-docs`` and return it converted to RGB."""
    sample_path = "./sample-docs/table-multi-row-column-cells.png"
    return Image.open(sample_path).convert("RGB")
27+
28+
1429@pytest .mark .parametrize (
1530 "model_path" ,
1631 [
@@ -328,13 +343,8 @@ def test_align_rows(rows, bbox, output):
328343 assert postprocess .align_rows (rows , bbox ) == output
329344
330345
331- def test_table_prediction_tesseract ():
332- table_model = tables .UnstructuredTableTransformerModel ()
333- from PIL import Image
334-
335- table_model .initialize (model = "microsoft/table-transformer-structure-recognition" )
336- img = Image .open ("./sample-docs/table-multi-row-column-cells.png" ).convert ("RGB" )
337- prediction = table_model .predict (img )
346+ def test_table_prediction_tesseract (table_transformer , example_image ):
347+ prediction = table_transformer .predict (example_image )
338348 # assert rows spans two rows are detected
339349 assert '<table><thead><th rowspan="2">' in prediction
340350 # one of the safest rows to detect should be present
@@ -351,28 +361,24 @@ def test_table_prediction_tesseract():
351361
352362
@pytest.mark.skipif(skip_outside_ci, reason="Skipping paddle test run outside of CI")
def test_table_prediction_paddle(monkeypatch, example_image):
    """Predict on the sample image with ``TABLE_OCR=paddle`` and expect a non-empty result."""
    # The environment variable is set before the model is created, so this test builds its
    # own model rather than reusing a shared fixture.
    monkeypatch.setenv("TABLE_OCR", "paddle")
    paddle_model = tables.UnstructuredTableTransformerModel()
    paddle_model.initialize(model="microsoft/table-transformer-structure-recognition")

    # Note(yuming): loosen paddle table prediction output test since performance issue
    # and results are different in different platforms (i.e., gpu vs cpu)
    assert len(paddle_model.predict(example_image))
365373
366374
def test_table_prediction_invalid_table_ocr(monkeypatch, example_image):
    """Expect a ``ValueError`` when ``TABLE_OCR`` is set to an unsupported value."""
    monkeypatch.setenv("TABLE_OCR", "invalid_table_ocr")
    # Creation, initialization and prediction all stay inside the `raises` block: the error
    # may surface at any of these steps.
    with pytest.raises(ValueError):
        model = tables.UnstructuredTableTransformerModel()
        model.initialize(model="microsoft/table-transformer-structure-recognition")
        model.predict(example_image)
376382
377383
378384def test_intersect ():
@@ -581,3 +587,40 @@ def test_cells_to_html():
581587 "cols</td></tr><tr><td></td><td>sub cell 1</td><td>sub cell 2</td></tr></table>"
582588 )
583589 assert tables .cells_to_html (cells ) == expected
590+
591+
def test_padded_results_has_right_dimensions(table_transformer, example_image):
    """Check that boxes predicted on a padded image map back to original-image coordinates.

    Structure is detected with ``pad_for_structure_detection=pad``. Two predicted boxes are
    then overwritten with known values, and ``tables.outputs_to_objects`` is expected to
    return them shifted by ``-pad`` relative to the padded image.
    """
    width, height = example_image.size
    # pad size is no more than 10% of the original image, which keeps the mocked boxes simple
    pad = int(min(example_image.size) / 10)
    padded_width = width + 2 * pad
    padded_height = height + 2 * pad

    # a simpler mapping: every class index is labelled "table cell" so that all predicted
    # structure is kept in the returned objects
    class_name2idx = tables.get_class_map("structure")
    class_idx2name = dict.fromkeys(class_name2idx.values(), "table cell")

    structure = table_transformer.get_structure(example_image, pad_for_structure_detection=pad)
    # NOTE(review): the values below assume boxes are (center_x, center_y, width, height),
    # normalized by the padded image size -- this matches the expected bboxes asserted below
    boxes = structure["pred_boxes"][0]

    # a box detected OUTSIDE of the original image (it spans the whole padded image); this
    # shouldn't happen but we want to make sure the code handles it as expected
    boxes[0, :2] = 0.5
    boxes[0, 2:] = 1.0
    expected_outer = [-pad, -pad, width + pad, height + pad]

    # a mocked box that is known to be safely inside the original image, at a known position
    expected_inner = [1, 3, 101, 53]
    boxes[1, :] = torch.tensor(
        [
            (51 + pad) / padded_width,
            (28 + pad) / padded_height,
            100 / padded_width,
            50 / padded_height,
        ],
    )

    objs = tables.outputs_to_objects(structure, example_image.size, class_idx2name)
    np.testing.assert_almost_equal(objs[0]["bbox"], expected_outer, decimal=4)
    np.testing.assert_almost_equal(objs[1]["bbox"], expected_inner, decimal=4)

    # a stricter test would constrain the actual detected boxes to lie within the original
    # image, but that requires the table transformer to behave in certain ways and does not
    # actually test the padding math; so here we use the relaxed condition
    right_limit = width + pad
    bottom_limit = height + pad
    for left, top, right, bottom in (obj["bbox"] for obj in objs[2:]):
        assert max(left, right) < right_limit
        assert max(top, bottom) < bottom_limit
0 commit comments