Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 1.1.0

* Enhancement: Add `TextSource` to track where the text of an element came from
* Enhancement: Refactor `__post_init__` of `TextRegions` and `LayoutElement` slightly to automate initialization

## 1.0.10

* Remove merging logic that's no longer used
Expand Down
5 changes: 3 additions & 2 deletions test_unstructured_inference/inference/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
EmbeddedTextRegion,
ImageTextRegion,
)
from unstructured_inference.constants import IsExtracted
from unstructured_inference.models.unstructuredmodel import (
UnstructuredElementExtractionModel,
UnstructuredObjectDetectionModel,
Expand All @@ -34,7 +35,7 @@ def mock_initial_layout():
6,
8,
text="A very repetitive narrative. " * 10,
source="Mock",
is_extracted=IsExtracted.TRUE,
)

title_block = EmbeddedTextRegion.from_coords(
Expand All @@ -43,7 +44,7 @@ def mock_initial_layout():
3,
4,
text="A Catchy Title",
source="Mock",
is_extracted=IsExtracted.TRUE,
)

return [text_block, title_block]
Expand Down
32 changes: 31 additions & 1 deletion test_unstructured_inference/inference/test_layout_element.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from unstructured_inference.inference.layoutelement import LayoutElement, TextRegion
from unstructured_inference.constants import IsExtracted, Source


def test_layout_element_do_dict(mock_layout_element):
def test_layout_element_to_dict(mock_layout_element):
expected = {
"coordinates": ((100, 100), (100, 300), (300, 300), (300, 100)),
"text": "Sample text",
"is_extracted": None,
"type": "Text",
"prob": None,
"source": None,
Expand All @@ -18,3 +20,31 @@ def test_layout_element_from_region(mock_rectangle):
region = TextRegion(bbox=mock_rectangle)

assert LayoutElement.from_region(region) == expected


def test_layoutelement_inheritance_works_correctly():
"""Test that LayoutElement properly inherits from TextRegion without conflicts"""
from unstructured_inference.inference.elements import TextRegion

# Create a TextRegion with both source and text_source
region = TextRegion.from_coords(
0, 0, 10, 10, text="test", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
)

# Convert to LayoutElement
element = LayoutElement.from_region(region)

# Check that both properties are preserved
assert element.source == Source.YOLOX, "LayoutElement should inherit source from TextRegion"
assert (
element.is_extracted == IsExtracted.TRUE
), "LayoutElement should inherit is_extracted from TextRegion"

# Check that to_dict() works correctly
d = element.to_dict()
assert d["source"] == Source.YOLOX
assert d["is_extracted"] == IsExtracted.TRUE

# Check that we can set source directly on LayoutElement
element.source = Source.DETECTRON2_ONNX
assert element.source == Source.DETECTRON2_ONNX
136 changes: 125 additions & 11 deletions test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import numpy as np
import pytest

from unstructured_inference.constants import IsExtracted, Source
from unstructured_inference.inference import elements
from unstructured_inference.inference.elements import (
Rectangle,
TextRegion,
TextRegions,
)
from unstructured_inference.inference.layoutelement import (
Expand Down Expand Up @@ -56,7 +58,7 @@ def test_layoutelements():
element_coords=coords,
element_class_ids=element_class_ids,
element_class_id_map=class_map,
source="yolox",
source=Source.YOLOX,
)


Expand Down Expand Up @@ -307,7 +309,7 @@ def test_clean_layoutelements(test_layoutelements):
elements[1].bbox.x2,
elements[1].bbox.x2,
) == (2, 2, 3, 3)
assert elements[0].source == elements[1].source == "yolox"
assert elements[0].source == elements[1].source == Source.YOLOX


@pytest.mark.parametrize(
Expand Down Expand Up @@ -408,29 +410,34 @@ def test_layoutelements_from_list_no_elements():

def test_textregions_from_list_no_elements():
back = TextRegions.from_list(regions=[])
assert back.sources.size == 0
assert back.source is None
assert back.is_extracted_array.size == 0
assert back.is_extracted is None
assert back.element_coords.size == 0


def test_layoutelements_concatenate():
layout1 = LayoutElements(
element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]),
texts=np.array(["a", "two"]),
source="yolox",
source=Source.YOLOX,
element_class_ids=np.array([0, 1]),
element_class_id_map={0: "type0", 1: "type1"},
)
layout2 = LayoutElements(
element_coords=np.array([[10, 10, 2, 2], [20, 20, 1, 1]]),
texts=np.array(["three", "4"]),
sources=np.array(["ocr", "ocr"]),
sources=np.array([Source.DETECTRON2_ONNX, Source.DETECTRON2_ONNX]),
element_class_ids=np.array([0, 1]),
element_class_id_map={0: "type1", 1: "type2"},
)
joint = LayoutElements.concatenate([layout1, layout2])
assert joint.texts.tolist() == ["a", "two", "three", "4"]
assert joint.sources.tolist() == ["yolox", "yolox", "ocr", "ocr"]
assert [s.value for s in joint.sources.tolist()] == [
"yolox",
"yolox",
"detectron2_onnx",
"detectron2_onnx",
]
assert joint.element_class_ids.tolist() == [0, 1, 1, 2]
assert joint.element_class_id_map == {0: "type0", 1: "type1", 2: "type2"}

Expand All @@ -449,8 +456,8 @@ def test_layoutelements_concatenate():
]
),
texts=np.array(["0", "1", "2", "3", "4"]),
sources=np.array(["foo", "foo", "foo", "foo", "foo"], dtype="<U3"),
source=np.str_("foo"),
is_extracted_array=np.array([IsExtracted.TRUE] * 5),
is_extracted=IsExtracted.TRUE,
),
LayoutElements(
element_coords=np.array(
Expand All @@ -463,8 +470,10 @@ def test_layoutelements_concatenate():
]
),
texts=np.array(["0", "1", "2", "3", "4"]),
sources=np.array(["foo", "foo", "foo", "foo", "foo"], dtype="<U3"),
source=np.str_("foo"),
sources=np.array([Source.YOLOX] * 5),
source=Source.YOLOX,
is_extracted_array=np.array([] * 5),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Empty array length mismatch in test fixture

The is_extracted_array in the LayoutElements test fixture is initialized with np.array([] * 5). This expression creates an empty array, which is inconsistent with the 5-element element_coords and the TextRegions example above it. This likely isn't the intended behavior and could lead to array length mismatches.

Fix in Cursor Fix in Web

is_extracted=IsExtracted.TRUE,
element_probs=np.array([0.0, 0.1, 0.2, 0.3, 0.4]),
),
],
Expand All @@ -479,3 +488,108 @@ def test_textregions_support_numpy_slicing(test_elements):
)
if isinstance(test_elements, LayoutElements):
np.testing.assert_almost_equal(test_elements[1:4].element_probs, np.array([0.1, 0.2, 0.3]))


def test_textregions_from_list_collects_sources():
"""Test that TextRegions.from_list() collects both source and text_source from regions"""
from unstructured_inference.inference.elements import TextRegion

regions = [
TextRegion.from_coords(
0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
),
TextRegion.from_coords(
10,
10,
20,
20,
text="second",
source=Source.DETECTRON2_ONNX,
is_extracted=IsExtracted.TRUE,
),
]

text_regions = TextRegions.from_list(regions)

# This should fail because from_list() doesn't collect sources
assert text_regions.sources.size > 0, "sources array should not be empty"
assert text_regions.sources[0] == Source.YOLOX
assert text_regions.sources[1] == Source.DETECTRON2_ONNX


def test_textregions_has_sources_field():
"""Test that TextRegions has a sources field"""
text_regions = TextRegions(element_coords=np.array([[0, 0, 10, 10]]))

# This should fail because TextRegions doesn't have a sources field
assert hasattr(text_regions, "sources"), "TextRegions should have a sources field"
assert hasattr(text_regions, "source"), "TextRegions should have a source field"


def test_textregions_iter_elements_preserves_source():
"""Test that TextRegions.iter_elements() preserves source property"""
from unstructured_inference.inference.elements import TextRegion

regions = [
TextRegion.from_coords(
0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
),
]
text_regions = TextRegions.from_list(regions)

elements = list(text_regions.iter_elements())

# This should fail because iter_elements() doesn't pass source to TextRegion.from_coords()
assert elements[0].source == Source.YOLOX, "iter_elements() should preserve source"


def test_textregions_slice_preserves_sources():
"""Test that TextRegions slicing preserves sources array"""
from unstructured_inference.inference.elements import TextRegion

regions = [
TextRegion.from_coords(
0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
),
TextRegion.from_coords(
10,
10,
20,
20,
text="second",
source=Source.DETECTRON2_ONNX,
is_extracted=IsExtracted.TRUE,
),
]
text_regions = TextRegions.from_list(regions)

sliced = text_regions[0:1]

# This should fail because slice() doesn't handle sources
assert sliced.sources.size > 0, "Sliced TextRegions should have sources"
assert sliced.sources[0] == Source.YOLOX
assert sliced.is_extracted_array[0] is IsExtracted.TRUE


def test_textregions_post_init_handles_sources():
"""Test that TextRegions.__post_init__() handles sources array initialization"""
# Create with source but no sources array
text_regions = TextRegions(
element_coords=np.array([[0, 0, 10, 10], [10, 10, 20, 20]]), source=Source.YOLOX
)

# This should fail because __post_init__() doesn't handle sources
assert text_regions.sources.size > 0, "sources should be initialized from source"
assert text_regions.sources[0] == Source.YOLOX
assert text_regions.sources[1] == Source.YOLOX


def test_textregions_from_coords_accepts_source():
"""Test that TextRegion.from_coords() accepts source parameter"""
# This should fail because from_coords() doesn't accept source parameter
region = TextRegion.from_coords(
0, 0, 10, 10, text="test", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
)

assert region.source == Source.YOLOX
assert region.is_extracted
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.10" # pragma: no cover
__version__ = "1.1.0" # pragma: no cover
6 changes: 6 additions & 0 deletions unstructured_inference/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ class Source(Enum):
DETECTRON2_LP = "detectron2_lp"


class IsExtracted(Enum):
TRUE = "true"
FALSE = "false"
PARTIAL = "partial"


class ElementType:
PARAGRAPH = "Paragraph"
IMAGE = "Image"
Expand Down
Loading