Skip to content

Commit 9c3f644

Browse files
authored
fix - table transformer predictions are filtered if confidence is below threshold (#338)
Add usage of the table-transformer-related thresholds. Predictions with a low confidence score are now filtered out.
1 parent 4304c83 commit 9c3f644

File tree

5 files changed

+94
-24
lines changed

5 files changed

+94
-24
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 0.7.29
2+
3+
* fix: table transformer predictions are now removed if confidence is below threshold
4+
5+
16
## 0.7.28
27

38
* feat: allow table transformer agent to return table prediction in unparsed format

test_unstructured_inference/models/test_tables.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import unstructured_inference.models.table_postprocess as postprocess
1212
from unstructured_inference.models import tables
13+
from unstructured_inference.models.tables import apply_thresholds_on_objects
1314

1415
skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}
1516

@@ -977,6 +978,55 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
977978
table_transformer.predict(example_image)
978979

979980

981+
@pytest.mark.parametrize(
    ("thresholds", "expected_object_number"),
    [
        ({"0": 0.5}, 1),
        ({"0": 0.1}, 3),
        ({"0": 0.9}, 0),
    ],
)
def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold(
    thresholds, expected_object_number
):
    """Single-class case: only objects scoring at or above the class threshold are kept."""
    # Three predictions of the same class with increasing confidence scores.
    objects = [
        {"label": "0", "score": 0.2},
        {"label": "0", "score": 0.4},
        {"label": "0", "score": 0.55},
    ]
    assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number
998+
999+
1000+
@pytest.mark.parametrize(
    ("thresholds", "expected_object_number"),
    [
        ({"0": 0.5, "1": 0.1}, 4),
        ({"0": 0.1, "1": 0.9}, 3),
        ({"0": 0.9, "1": 0.5}, 1),
    ],
)
def test_objects_are_filtered_based_on_class_thresholds_when_two_classes(
    thresholds, expected_object_number
):
    """Two-class case: each object is compared against the threshold of its own label."""
    # The same three scores are repeated for two labels, so the expected count
    # is the sum of the per-class survivors.
    objects = [
        {"label": "0", "score": 0.2},
        {"label": "0", "score": 0.4},
        {"label": "0", "score": 0.55},
        {"label": "1", "score": 0.2},
        {"label": "1", "score": 0.4},
        {"label": "1", "score": 0.55},
    ]
    assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number
1020+
1021+
1022+
def test_objects_filtering_when_missing_threshold():
    """A label without a configured threshold raises KeyError naming that label."""
    class_name = "class_name"
    objects = [{"label": class_name, "score": 0.2}]
    # The thresholds mapping deliberately has no entry for `class_name`.
    thresholds = {"1": 0.5}
    with pytest.raises(KeyError, match=class_name):
        apply_thresholds_on_objects(objects, thresholds)
1028+
1029+
9801030
def test_intersect():
9811031
a = postprocess.Rect()
9821032
b = postprocess.Rect([1, 2, 3, 4])
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.7.28" # pragma: no cover
1+
__version__ = "0.7.29" # pragma: no cover

unstructured_inference/models/table_postprocess.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -80,24 +80,6 @@ def apply_threshold(objects, threshold):
8080
return [obj for obj in objects if obj["score"] >= threshold]
8181

8282

83-
# def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
84-
# """
85-
# Filter out bounding boxes whose confidence is below the confidence threshold for
86-
# its associated class label.
87-
# """
88-
# # Apply class-specific thresholds
89-
# indices_above_threshold = [
90-
# idx
91-
# for idx, (score, label) in enumerate(zip(scores, labels))
92-
# if score >= class_thresholds[class_names[label]]
93-
# ]
94-
# bboxes = [bboxes[idx] for idx in indices_above_threshold]
95-
# scores = [scores[idx] for idx in indices_above_threshold]
96-
# labels = [labels[idx] for idx in indices_above_threshold]
97-
98-
# return bboxes, scores, labels
99-
100-
10183
def refine_rows(rows, tokens, score_threshold):
10284
"""
10385
Apply operations to the detected rows, such as

unstructured_inference/models/tables.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
import xml.etree.ElementTree as ET
44
from collections import defaultdict
55
from pathlib import Path
6-
from typing import Dict, List, Optional, Union
6+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
77

88
import cv2
99
import numpy as np
1010
import torch
1111
from PIL import Image as PILImage
1212
from transformers import DetrImageProcessor, TableTransformerForObjectDetection
13+
from transformers.models.table_transformer.modeling_table_transformer import (
14+
TableTransformerObjectDetectionOutput,
15+
)
1316

1417
from unstructured_inference.config import inference_config
1518
from unstructured_inference.inference.layoutelement import table_cells_to_dataframe
@@ -172,18 +175,22 @@ def recognize(outputs: dict, img: PILImage.Image, tokens: list):
172175
"""Recognize table elements."""
173176
str_class_name2idx = get_class_map("structure")
174177
str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}
175-
str_class_thresholds = structure_class_thresholds
178+
class_thresholds = structure_class_thresholds
176179

177180
# Post-process detected objects, assign class labels
178181
objects = outputs_to_objects(outputs, img.size, str_class_idx2name)
179-
182+
high_confidence_objects = apply_thresholds_on_objects(objects, class_thresholds)
180183
# Further process the detected objects so they correspond to a consistent table
181-
tables_structure = objects_to_structures(objects, tokens, str_class_thresholds)
184+
tables_structure = objects_to_structures(high_confidence_objects, tokens, class_thresholds)
182185
# Enumerate all table cells: grid cells and spanning cells
183186
return [structure_to_cells(structure, tokens)[0] for structure in tables_structure]
184187

185188

186-
def outputs_to_objects(outputs, img_size, class_idx2name):
189+
def outputs_to_objects(
190+
outputs: TableTransformerObjectDetectionOutput,
191+
img_size: tuple[int, int],
192+
class_idx2name: Mapping[int, str],
193+
):
187194
"""Output table element types."""
188195
m = outputs["logits"].softmax(-1).max(-1)
189196
pred_labels = list(m.indices.detach().cpu().numpy())[0]
@@ -212,6 +219,32 @@ def outputs_to_objects(outputs, img_size, class_idx2name):
212219
return objects
213220

214221

222+
def apply_thresholds_on_objects(
    objects: Sequence[Mapping[str, Any]], thresholds: Mapping[str, float]
) -> Sequence[Mapping[str, Any]]:
    """Keep only the predicted objects whose score reaches the threshold of their label.

    Args:
        objects: Sequence of mappings, each providing at least a "label" and a
            "score" entry, for example:
            [
                {
                    "label": "table row",
                    "score": 0.55,
                    "bbox": [...],
                },
                ...,
            ]
        thresholds: Mapping from labels to the minimum accepted score.

    Returns:
        A new list holding, in their original order, the objects whose score is
        greater than or equal to the threshold of their label. The input
        sequence is left untouched.

    Raises:
        KeyError: If the label of an object has no entry in ``thresholds``.
    """
    # Direct indexing (not .get) is intentional: a label without a configured
    # threshold must surface as a KeyError rather than be silently kept or dropped.
    return [obj for obj in objects if obj["score"] >= thresholds[obj["label"]]]
246+
247+
215248
# for output bounding box post-processing
216249
def box_cxcywh_to_xyxy(x):
217250
"""Convert rectangle format from center-x, center-y, width, height to

0 commit comments

Comments
 (0)