Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 1.0.6

* Add slicing through indexing for vectorized elements

## 1.0.5

* feat: add thread lock to prevent racing condition when instantiating singletons
Expand Down
46 changes: 46 additions & 0 deletions test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,49 @@ def test_layoutelements_concatenate():
assert joint.sources.tolist() == ["yolox", "yolox", "ocr", "ocr"]
assert joint.element_class_ids.tolist() == [0, 1, 1, 2]
assert joint.element_class_id_map == {0: "type0", 1: "type1", 2: "type2"}


@pytest.mark.parametrize(
"test_elements",
[
TextRegions(
element_coords=np.array(
[
[0.0, 0.0, 1.0, 1.0],
[1.0, 0.0, 1.5, 1.0],
[2.0, 0.0, 2.5, 1.0],
[3.0, 0.0, 4.0, 1.0],
[4.0, 0.0, 5.0, 1.0],
]
),
texts=np.array(["0", "1", "2", "3", "4"]),
sources=np.array(["foo", "foo", "foo", "foo", "foo"], dtype="<U3"),
source=np.str_("foo"),
),
LayoutElements(
element_coords=np.array(
[
[0.0, 0.0, 1.0, 1.0],
[1.0, 0.0, 1.5, 1.0],
[2.0, 0.0, 2.5, 1.0],
[3.0, 0.0, 4.0, 1.0],
[4.0, 0.0, 5.0, 1.0],
]
),
texts=np.array(["0", "1", "2", "3", "4"]),
sources=np.array(["foo", "foo", "foo", "foo", "foo"], dtype="<U3"),
source=np.str_("foo"),
element_probs=np.array([0.0, 0.1, 0.2, 0.3, 0.4]),
),
],
)
def test_textregions_support_numpy_slicing(test_elements):
np.testing.assert_equal(test_elements[1:4].texts, np.array(["1", "2", "3"]))
np.testing.assert_equal(test_elements[0::2].texts, np.array(["0", "2", "4"]))
np.testing.assert_equal(test_elements[[1, 2, 4]].texts, np.array(["1", "2", "4"]))
np.testing.assert_equal(test_elements[np.array([1, 2, 4])].texts, np.array(["1", "2", "4"]))
np.testing.assert_equal(
test_elements[np.array([True, False, False, True, False])].texts, np.array(["0", "3"])
)
if isinstance(test_elements, LayoutElements):
np.testing.assert_almost_equal(test_elements[1:4].element_probs, np.array([0.1, 0.2, 0.3]))
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.5" # pragma: no cover
__version__ = "1.0.6" # pragma: no cover
3 changes: 3 additions & 0 deletions unstructured_inference/inference/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def __post_init__(self):
# we convert to float so data type is more consistent (e.g., None will be np.nan)
self.element_coords = self.element_coords.astype(float)

def __getitem__(self, indices) -> TextRegions:
return self.slice(indices)

def slice(self, indices) -> TextRegions:
"""slice text regions based on indices"""
return TextRegions(
Expand Down
3 changes: 3 additions & 0 deletions unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def __eq__(self, other: object) -> bool:
and np.array_equal(self.table_as_cells[mask], other.table_as_cells[mask])
)

def __getitem__(self, indices):
return self.slice(indices)

def slice(self, indices) -> LayoutElements:
"""slice and return only selected indices"""
return LayoutElements(
Expand Down