Skip to content

Commit c53866b

Browse files
authored
fix: fix bugs in layoutelements (#393)
This PR fixes two bugs: - fix a type casting issue when subtracting a float from an `int` array. This popped up when testing with `unstructured`, where some sources of element coordinates are of `int` type. This PR adds a new unit test case for `int` coord type with the grouping function - fix the bug where element class id 0 becomes `None`: this happens when dumping `LayoutElements` as a list of `LayoutElement`. When an element class id is 0, the logic on main would treat it as nonexistent and use `None` as the type.
1 parent 4431fe5 commit c53866b

File tree

6 files changed

+76
-10
lines changed

6 files changed

+76
-10
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 0.7.41
2+
3+
* fix: fix incorrect type casting with higher versions of `numpy` when subtracting a `float` from an `int` array
4+
* fix: fix a bug where class id 0 becomes class type `None` when calling `LayoutElements.as_list()`
5+
16
## 0.7.40
27

38
* fix: store probabilities with `float` data type instead of `int`

test_unstructured_inference/test_elements.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,10 @@ def test_minimal_containing_rect():
143143
assert rect2.is_in(big_rect)
144144

145145

146-
def test_partition_groups_from_regions(mock_embedded_text_regions):
146+
@pytest.mark.parametrize("coord_type", [int, float])
147+
def test_partition_groups_from_regions(mock_embedded_text_regions, coord_type):
147148
words = TextRegions.from_list(mock_embedded_text_regions)
149+
words.element_coords = words.element_coords.astype(coord_type)
148150
groups = partition_groups_from_regions(words)
149151
assert len(groups) == 1
150152
text = "".join(groups[-1].texts)
@@ -421,3 +423,14 @@ def test_clean_layoutelements_for_class(
421423
elements = clean_layoutelements_for_class(elements, element_class=class_to_filter)
422424
np.testing.assert_array_equal(elements.element_coords, expected_coords)
423425
np.testing.assert_array_equal(elements.element_class_ids, expected_ids)
426+
427+
428+
def test_layoutelements_to_list_and_back(test_layoutelements):
429+
back = LayoutElements.from_list(test_layoutelements.as_list())
430+
np.testing.assert_array_equal(test_layoutelements.element_coords, back.element_coords)
431+
np.testing.assert_array_equal(test_layoutelements.texts, back.texts)
432+
assert all(np.isnan(back.element_probs))
433+
assert [
434+
test_layoutelements.element_class_id_map[idx]
435+
for idx in test_layoutelements.element_class_ids
436+
] == [back.element_class_id_map[idx] for idx in back.element_class_ids]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.7.40" # pragma: no cover
1+
__version__ = "0.7.41" # pragma: no cover

unstructured_inference/inference/elements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def as_list(self):
237237
]
238238

239239
@classmethod
240-
def from_list(cls, regions: list[TextRegion]):
240+
def from_list(cls, regions: list):
241241
"""create TextRegions from a list of TextRegion objects; the objects must have the same
242242
source"""
243243
coords, texts = [], []

unstructured_inference/inference/layoutelement.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,34 @@
3232
class LayoutElements(TextRegions):
3333
element_probs: np.ndarray = field(default_factory=lambda: np.array([]))
3434
element_class_ids: np.ndarray = field(default_factory=lambda: np.array([]))
35-
element_class_id_map: dict[int, str] | None = None
35+
element_class_id_map: dict[int, str] = field(default_factory=dict)
3636

3737
def __post_init__(self):
38-
if self.element_probs is not None:
39-
self.element_probs = self.element_probs.astype(float)
4038
element_size = self.element_coords.shape[0]
4139
for attr in ("element_probs", "element_class_ids", "texts"):
4240
if getattr(self, attr).size == 0 and element_size:
4341
setattr(self, attr, np.array([None] * element_size))
4442

43+
self.element_probs = self.element_probs.astype(float)
44+
45+
def __eq__(self, other: object) -> bool:
46+
if not isinstance(other, LayoutElements):
47+
return NotImplemented
48+
49+
mask = ~np.isnan(self.element_probs)
50+
other_mask = ~np.isnan(other.element_probs)
51+
return (
52+
np.array_equal(self.element_coords, other.element_coords)
53+
and np.array_equal(self.texts, other.texts)
54+
and np.array_equal(mask, other_mask)
55+
and np.array_equal(self.element_probs[mask], other.element_probs[mask])
56+
and (
57+
[self.element_class_id_map[idx] for idx in self.element_class_ids]
58+
== [other.element_class_id_map[idx] for idx in other.element_class_ids]
59+
)
60+
and self.source == other.source
61+
)
62+
4563
def slice(self, indices) -> LayoutElements:
4664
"""slice and return only selected indices"""
4765
return LayoutElements(
@@ -85,10 +103,10 @@ def as_list(self):
85103
text=text,
86104
type=(
87105
self.element_class_id_map[class_id]
88-
if class_id and self.element_class_id_map
106+
if class_id is not None and self.element_class_id_map
89107
else None
90108
),
91-
prob=prob,
109+
prob=None if np.isnan(prob) else prob,
92110
source=self.source,
93111
)
94112
for (x1, y1, x2, y2), text, prob, class_id in zip(
@@ -99,6 +117,36 @@ def as_list(self):
99117
)
100118
]
101119

120+
@classmethod
121+
def from_list(cls, elements: list):
122+
"""create LayoutElements from a list of LayoutElement objects; the objects must have the
123+
same source"""
124+
len_ele = len(elements)
125+
coords = np.empty((len_ele, 4), dtype=float)
126+
# text and probs can be Nones so use lists first then convert into array to avoid them being
127+
# filled as nan
128+
texts = []
129+
class_probs = []
130+
class_types = np.empty((len_ele,), dtype="object")
131+
132+
for i, element in enumerate(elements):
133+
coords[i] = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
134+
texts.append(element.text)
135+
class_probs.append(element.prob)
136+
class_types[i] = element.type or "None"
137+
138+
unique_ids, class_ids = np.unique(class_types, return_inverse=True)
139+
unique_ids[unique_ids == "None"] = None
140+
141+
return cls(
142+
element_coords=coords,
143+
texts=np.array(texts),
144+
element_probs=np.array(class_probs),
145+
element_class_ids=class_ids,
146+
element_class_id_map=dict(zip(range(len(unique_ids)), unique_ids)),
147+
source=elements[0].source,
148+
)
149+
102150

103151
@dataclass
104152
class LayoutElement(TextRegion):
@@ -315,7 +363,7 @@ def partition_groups_from_regions(regions: TextRegions) -> List[TextRegions]:
315363
regions, each list corresponding with a group"""
316364
if len(regions) == 0:
317365
return []
318-
padded_coords = regions.element_coords.copy()
366+
padded_coords = regions.element_coords.copy().astype(float)
319367
v_pad = (regions.y2 - regions.y1) * inference_config.ELEMENTS_V_PADDING_COEF
320368
h_pad = (regions.x2 - regions.x1) * inference_config.ELEMENTS_H_PADDING_COEF
321369
padded_coords[:, 0] -= h_pad

unstructured_inference/models/yolox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def image_processing(
136136
sorted_dets = dets[order]
137137

138138
return LayoutElements(
139-
element_coords=sorted_dets[:, :4],
139+
element_coords=sorted_dets[:, :4].astype(float),
140140
element_probs=sorted_dets[:, 4].astype(float),
141141
element_class_ids=sorted_dets[:, 5].astype(int),
142142
element_class_id_map=self.layout_classes,

0 commit comments

Comments
 (0)