Commit cb2aff2

fix: padded boxes are not rescaled/shifted correctly (#229)
1 parent 35ebea7 commit cb2aff2

File tree: 4 files changed (+67, -20 lines)

CHANGELOG.md

Lines changed: 6 additions & 2 deletions
```diff
@@ -1,12 +1,16 @@
+## 0.6.3
+
+* fix a bug where padded table structure bounding boxes are not shifted back into the original image coordinates correctly
+
 ## 0.6.2
 
 * move the confidence threshold for table transformer to config
 
 ## 0.6.1
 
 * YoloX_quantized is now the default model. This models detects most diverse types and detect tables better than previous model.
-* Since detection models tend to nest elements inside others(specifically in Tables), an algorithm has been added for reducing this
-behavior. Now all the elements produced by detection models are disjoint and they don't produce overlapping regions, which helps
+* Since detection models tend to nest elements inside others(specifically in Tables), an algorithm has been added for reducing this
+behavior. Now all the elements produced by detection models are disjoint and they don't produce overlapping regions, which helps
 reduce duplicated content.
 * Add `source` property to our elements, so you can know where the information was generated (OCR or detection model)
 
```

test_unstructured_inference/models/test_tables.py

Lines changed: 58 additions & 15 deletions
```diff
@@ -1,6 +1,9 @@
 import os
 
+import numpy as np
 import pytest
+import torch
+from PIL import Image
 from transformers.models.table_transformer.modeling_table_transformer import (
     TableTransformerDecoder,
 )
@@ -11,6 +14,18 @@
 skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}
 
 
+@pytest.fixture()
+def table_transformer():
+    table_model = tables.UnstructuredTableTransformerModel()
+    table_model.initialize(model="microsoft/table-transformer-structure-recognition")
+    return table_model
+
+
+@pytest.fixture()
+def example_image():
+    return Image.open("./sample-docs/table-multi-row-column-cells.png").convert("RGB")
+
+
 @pytest.mark.parametrize(
     "model_path",
     [
@@ -328,13 +343,8 @@ def test_align_rows(rows, bbox, output):
     assert postprocess.align_rows(rows, bbox) == output
 
 
-def test_table_prediction_tesseract():
-    table_model = tables.UnstructuredTableTransformerModel()
-    from PIL import Image
-
-    table_model.initialize(model="microsoft/table-transformer-structure-recognition")
-    img = Image.open("./sample-docs/table-multi-row-column-cells.png").convert("RGB")
-    prediction = table_model.predict(img)
+def test_table_prediction_tesseract(table_transformer, example_image):
+    prediction = table_transformer.predict(example_image)
     # assert rows spans two rows are detected
     assert '<table><thead><th rowspan="2">' in prediction
     # one of the safest rows to detect should be present
@@ -351,28 +361,24 @@ def test_table_prediction_tesseract():
 
 
 @pytest.mark.skipif(skip_outside_ci, reason="Skipping paddle test run outside of CI")
-def test_table_prediction_paddle(monkeypatch):
+def test_table_prediction_paddle(monkeypatch, example_image):
     monkeypatch.setenv("TABLE_OCR", "paddle")
     table_model = tables.UnstructuredTableTransformerModel()
-    from PIL import Image
 
     table_model.initialize(model="microsoft/table-transformer-structure-recognition")
-    img = Image.open("./sample-docs/table-multi-row-column-cells.png").convert("RGB")
-    prediction = table_model.predict(img)
+    prediction = table_model.predict(example_image)
     # Note(yuming): lossen paddle table prediction output test since performance issue
     # and results are different in different platforms (i.e., gpu vs cpu)
     assert len(prediction)
 
 
-def test_table_prediction_invalid_table_ocr(monkeypatch):
+def test_table_prediction_invalid_table_ocr(monkeypatch, example_image):
     monkeypatch.setenv("TABLE_OCR", "invalid_table_ocr")
     with pytest.raises(ValueError):
         table_model = tables.UnstructuredTableTransformerModel()
-        from PIL import Image
 
         table_model.initialize(model="microsoft/table-transformer-structure-recognition")
-        img = Image.open("./sample-docs/table-multi-row-column-cells.png").convert("RGB")
-        _ = table_model.predict(img)
+        _ = table_model.predict(example_image)
 
 
 def test_intersect():
@@ -581,3 +587,40 @@ def test_cells_to_html():
         "cols</td></tr><tr><td></td><td>sub cell 1</td><td>sub cell 2</td></tr></table>"
     )
     assert tables.cells_to_html(cells) == expected
+
+
+def test_padded_results_has_right_dimensions(table_transformer, example_image):
+    str_class_name2idx = tables.get_class_map("structure")
+    # a simpler mapping so we keep all structure in the returned objs below for test
+    str_class_idx2name = {v: "table cell" for v in str_class_name2idx.values()}
+    # pad size is no more than 10% of the original image so we can set up the test below more easily
+    pad = int(min(example_image.size) / 10)
+
+    structure = table_transformer.get_structure(example_image, pad_for_structure_detection=pad)
+    # boxes detected OUTSIDE of the original image; this shouldn't happen but we want to make sure
+    # the code handles it as expected
+    structure["pred_boxes"][0][0, :2] = 0.5
+    structure["pred_boxes"][0][0, 2:] = 1.0
+    # mock a box we know is safely inside the original image with known positions
+    width, height = example_image.size
+    padded_width = width + pad * 2
+    padded_height = height + pad * 2
+    original = [1, 3, 101, 53]
+    structure["pred_boxes"][0][1, :] = torch.tensor(
+        [
+            (51 + pad) / padded_width,
+            (28 + pad) / padded_height,
+            100 / padded_width,
+            50 / padded_height,
+        ],
+    )
+    objs = tables.outputs_to_objects(structure, example_image.size, str_class_idx2name)
+    np.testing.assert_almost_equal(objs[0]["bbox"], [-pad, -pad, width + pad, height + pad], 4)
+    np.testing.assert_almost_equal(objs[1]["bbox"], original, 4)
+    # a stricter test would constrain the actual detected boxes to be within the original
+    # image, but that requires the table transformer to behave in certain ways and does not
+    # actually test the padding math; so here we use the relaxed condition
+    for obj in objs[2:]:
+        x1, y1, x2, y2 = obj["bbox"]
+        assert max(x1, x2) < width + pad
+        assert max(y1, y2) < height + pad
```
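The expected values in the new test follow from the same math; a short check of the mocked box, assuming the normalized (cx, cy, w, h) box convention used by the model:

```python
# box 1 is mocked at center (51 + pad, 28 + pad) with size 100 x 50 on the
# padded canvas; converting the center format to corners and unshifting by
# the per-side pad recovers original-image coordinates for any pad size:
#   x1 = (51 + pad) - 100 / 2 - pad = 1
#   y1 = (28 + pad) -  50 / 2 - pad = 3
#   x2 = (51 + pad) + 100 / 2 - pad = 101
#   y2 = (28 + pad) +  50 / 2 - pad = 53
# hence np.testing.assert_almost_equal(objs[1]["bbox"], [1, 3, 101, 53], 4)
```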
unstructured_inference/__version__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-__version__ = "0.6.2" # pragma: no cover
+__version__ = "0.6.3" # pragma: no cover
```

unstructured_inference/models/tables.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -220,11 +220,11 @@ def outputs_to_objects(outputs, img_size, class_idx2name):
     pred_bboxes = outputs["pred_boxes"].detach().cpu()[0]
 
     pad = outputs.get("pad_for_structure_detection", 0)
-    scale_size = (img_size[0] + pad, img_size[1] + pad)
+    scale_size = (img_size[0] + pad * 2, img_size[1] + pad * 2)
     pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, scale_size)]
     # unshift the padding; padding effectively shifted the bounding boxes of structures in the
     # original image with half of the total pad
-    shift_size = pad / 2
+    shift_size = pad
 
     objects = []
     for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
```
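Putting the two corrected lines together, a minimal runnable sketch of the rescale-and-unshift math (`unpad_bboxes` is a hypothetical name for illustration; the real logic lives in `outputs_to_objects` and `rescale_bboxes`):

```python
import torch


def unpad_bboxes(pred_boxes: torch.Tensor, img_size: tuple, pad: int) -> torch.Tensor:
    """Map normalized (cx, cy, w, h) boxes predicted on a padded image back to
    (x1, y1, x2, y2) pixel coordinates of the original image."""
    width, height = img_size
    # the image was padded by `pad` pixels on EACH side, so the padded canvas
    # is 2 * pad larger in each dimension (the bug added pad only once)
    scale = torch.tensor([width + 2 * pad, height + 2 * pad] * 2)
    cx, cy, w, h = pred_boxes.unbind(-1)
    corners = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)
    # unshift by the full per-side pad (the bug shifted by pad / 2)
    return corners * scale - pad


# mirrors the box mocked in the new test; expected output is approximately
# tensor([[  1.,   3., 101.,  53.]])
width, height, pad = 1000, 800, 100
pw, ph = width + 2 * pad, height + 2 * pad
box = torch.tensor([[(51 + pad) / pw, (28 + pad) / ph, 100 / pw, 50 / ph]])
print(unpad_bboxes(box, (width, height), pad))
```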
