Skip to content

Commit f8bf10d

Browse files
authored
enhancement: add text sources (#449)
Added `IsExtracted` for tracking source of element text. <!-- CURSOR_SUMMARY --> --- > [!NOTE] > Introduces `IsExtracted` across `TextRegion`/`TextRegions`/`LayoutElement(s)` with refactored auto-initialization, updates slicing/concat/cleaning/equality to carry the flag and sources, and bumps version to 1.1.0. > > - **Core models**: > - Add `IsExtracted` enum; extend `TextRegion` with `is_extracted` and `from_coords(..., is_extracted=...)`. > - Refactor `TextRegions.__post_init__` to auto-initialize optional arrays from scalar fields (`source` → `sources`, `is_extracted` → `is_extracted_array`). > - Ensure slicing, iteration, and `from_list` preserve `sources` and `is_extracted_array`. > - **Layout elements**: > - Propagate `is_extracted` through `LayoutElement`/`LayoutElements` (`to_dict`, `from_region`, `from_coords`, `from_list`, `concatenate`, `iter_elements`, `slice`, `__eq__`). > - Include `is_extracted_array` in cleaning utilities (`clean_layoutelements*`) and concatenation outputs. > - **Tests**: > - Update/expand tests to validate `sources` and `is_extracted` propagation, slicing, `from_list`, and inheritance behavior. > - **Release**: > - Bump version to `1.1.0` and update `CHANGELOG.md`. > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 03af14e. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY -->
1 parent e6c8d19 commit f8bf10d

File tree

8 files changed

+299
-52
lines changed

8 files changed

+299
-52
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 1.1.0
2+
3+
* Enhancement: Add `TextSource` to track where the text of an element came from
4+
* Enhancement: Refactor `__post_init__` of `TextRegions` and `LayoutElement` slightly to automate initialization
5+
16
## 1.0.10
27

38
* Remove merging logic that's no longer used

test_unstructured_inference/inference/test_layout.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
EmbeddedTextRegion,
1414
ImageTextRegion,
1515
)
16+
from unstructured_inference.constants import IsExtracted
1617
from unstructured_inference.models.unstructuredmodel import (
1718
UnstructuredElementExtractionModel,
1819
UnstructuredObjectDetectionModel,
@@ -34,7 +35,7 @@ def mock_initial_layout():
3435
6,
3536
8,
3637
text="A very repetitive narrative. " * 10,
37-
source="Mock",
38+
is_extracted=IsExtracted.TRUE,
3839
)
3940

4041
title_block = EmbeddedTextRegion.from_coords(
@@ -43,7 +44,7 @@ def mock_initial_layout():
4344
3,
4445
4,
4546
text="A Catchy Title",
46-
source="Mock",
47+
is_extracted=IsExtracted.TRUE,
4748
)
4849

4950
return [text_block, title_block]

test_unstructured_inference/inference/test_layout_element.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from unstructured_inference.inference.layoutelement import LayoutElement, TextRegion
2+
from unstructured_inference.constants import IsExtracted, Source
23

34

4-
def test_layout_element_do_dict(mock_layout_element):
5+
def test_layout_element_to_dict(mock_layout_element):
56
expected = {
67
"coordinates": ((100, 100), (100, 300), (300, 300), (300, 100)),
78
"text": "Sample text",
9+
"is_extracted": None,
810
"type": "Text",
911
"prob": None,
1012
"source": None,
@@ -18,3 +20,31 @@ def test_layout_element_from_region(mock_rectangle):
1820
region = TextRegion(bbox=mock_rectangle)
1921

2022
assert LayoutElement.from_region(region) == expected
23+
24+
25+
def test_layoutelement_inheritance_works_correctly():
26+
"""Test that LayoutElement properly inherits from TextRegion without conflicts"""
27+
from unstructured_inference.inference.elements import TextRegion
28+
29+
# Create a TextRegion with both source and text_source
30+
region = TextRegion.from_coords(
31+
0, 0, 10, 10, text="test", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
32+
)
33+
34+
# Convert to LayoutElement
35+
element = LayoutElement.from_region(region)
36+
37+
# Check that both properties are preserved
38+
assert element.source == Source.YOLOX, "LayoutElement should inherit source from TextRegion"
39+
assert (
40+
element.is_extracted == IsExtracted.TRUE
41+
), "LayoutElement should inherit is_extracted from TextRegion"
42+
43+
# Check that to_dict() works correctly
44+
d = element.to_dict()
45+
assert d["source"] == Source.YOLOX
46+
assert d["is_extracted"] == IsExtracted.TRUE
47+
48+
# Check that we can set source directly on LayoutElement
49+
element.source = Source.DETECTRON2_ONNX
50+
assert element.source == Source.DETECTRON2_ONNX

test_unstructured_inference/test_elements.py

Lines changed: 125 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import numpy as np
66
import pytest
77

8+
from unstructured_inference.constants import IsExtracted, Source
89
from unstructured_inference.inference import elements
910
from unstructured_inference.inference.elements import (
1011
Rectangle,
12+
TextRegion,
1113
TextRegions,
1214
)
1315
from unstructured_inference.inference.layoutelement import (
@@ -56,7 +58,7 @@ def test_layoutelements():
5658
element_coords=coords,
5759
element_class_ids=element_class_ids,
5860
element_class_id_map=class_map,
59-
source="yolox",
61+
source=Source.YOLOX,
6062
)
6163

6264

@@ -307,7 +309,7 @@ def test_clean_layoutelements(test_layoutelements):
307309
elements[1].bbox.x2,
308310
elements[1].bbox.x2,
309311
) == (2, 2, 3, 3)
310-
assert elements[0].source == elements[1].source == "yolox"
312+
assert elements[0].source == elements[1].source == Source.YOLOX
311313

312314

313315
@pytest.mark.parametrize(
@@ -408,29 +410,34 @@ def test_layoutelements_from_list_no_elements():
408410

409411
def test_textregions_from_list_no_elements():
410412
back = TextRegions.from_list(regions=[])
411-
assert back.sources.size == 0
412-
assert back.source is None
413+
assert back.is_extracted_array.size == 0
414+
assert back.is_extracted is None
413415
assert back.element_coords.size == 0
414416

415417

416418
def test_layoutelements_concatenate():
417419
layout1 = LayoutElements(
418420
element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]),
419421
texts=np.array(["a", "two"]),
420-
source="yolox",
422+
source=Source.YOLOX,
421423
element_class_ids=np.array([0, 1]),
422424
element_class_id_map={0: "type0", 1: "type1"},
423425
)
424426
layout2 = LayoutElements(
425427
element_coords=np.array([[10, 10, 2, 2], [20, 20, 1, 1]]),
426428
texts=np.array(["three", "4"]),
427-
sources=np.array(["ocr", "ocr"]),
429+
sources=np.array([Source.DETECTRON2_ONNX, Source.DETECTRON2_ONNX]),
428430
element_class_ids=np.array([0, 1]),
429431
element_class_id_map={0: "type1", 1: "type2"},
430432
)
431433
joint = LayoutElements.concatenate([layout1, layout2])
432434
assert joint.texts.tolist() == ["a", "two", "three", "4"]
433-
assert joint.sources.tolist() == ["yolox", "yolox", "ocr", "ocr"]
435+
assert [s.value for s in joint.sources.tolist()] == [
436+
"yolox",
437+
"yolox",
438+
"detectron2_onnx",
439+
"detectron2_onnx",
440+
]
434441
assert joint.element_class_ids.tolist() == [0, 1, 1, 2]
435442
assert joint.element_class_id_map == {0: "type0", 1: "type1", 2: "type2"}
436443

@@ -449,8 +456,8 @@ def test_layoutelements_concatenate():
449456
]
450457
),
451458
texts=np.array(["0", "1", "2", "3", "4"]),
452-
sources=np.array(["foo", "foo", "foo", "foo", "foo"], dtype="<U3"),
453-
source=np.str_("foo"),
459+
is_extracted_array=np.array([IsExtracted.TRUE] * 5),
460+
is_extracted=IsExtracted.TRUE,
454461
),
455462
LayoutElements(
456463
element_coords=np.array(
@@ -463,8 +470,10 @@ def test_layoutelements_concatenate():
463470
]
464471
),
465472
texts=np.array(["0", "1", "2", "3", "4"]),
466-
sources=np.array(["foo", "foo", "foo", "foo", "foo"], dtype="<U3"),
467-
source=np.str_("foo"),
473+
sources=np.array([Source.YOLOX] * 5),
474+
source=Source.YOLOX,
475+
is_extracted_array=np.array([] * 5),
476+
is_extracted=IsExtracted.TRUE,
468477
element_probs=np.array([0.0, 0.1, 0.2, 0.3, 0.4]),
469478
),
470479
],
@@ -479,3 +488,108 @@ def test_textregions_support_numpy_slicing(test_elements):
479488
)
480489
if isinstance(test_elements, LayoutElements):
481490
np.testing.assert_almost_equal(test_elements[1:4].element_probs, np.array([0.1, 0.2, 0.3]))
491+
492+
493+
def test_textregions_from_list_collects_sources():
494+
"""Test that TextRegions.from_list() collects both source and text_source from regions"""
495+
from unstructured_inference.inference.elements import TextRegion
496+
497+
regions = [
498+
TextRegion.from_coords(
499+
0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
500+
),
501+
TextRegion.from_coords(
502+
10,
503+
10,
504+
20,
505+
20,
506+
text="second",
507+
source=Source.DETECTRON2_ONNX,
508+
is_extracted=IsExtracted.TRUE,
509+
),
510+
]
511+
512+
text_regions = TextRegions.from_list(regions)
513+
514+
# This should fail because from_list() doesn't collect sources
515+
assert text_regions.sources.size > 0, "sources array should not be empty"
516+
assert text_regions.sources[0] == Source.YOLOX
517+
assert text_regions.sources[1] == Source.DETECTRON2_ONNX
518+
519+
520+
def test_textregions_has_sources_field():
521+
"""Test that TextRegions has a sources field"""
522+
text_regions = TextRegions(element_coords=np.array([[0, 0, 10, 10]]))
523+
524+
# This should fail because TextRegions doesn't have a sources field
525+
assert hasattr(text_regions, "sources"), "TextRegions should have a sources field"
526+
assert hasattr(text_regions, "source"), "TextRegions should have a source field"
527+
528+
529+
def test_textregions_iter_elements_preserves_source():
530+
"""Test that TextRegions.iter_elements() preserves source property"""
531+
from unstructured_inference.inference.elements import TextRegion
532+
533+
regions = [
534+
TextRegion.from_coords(
535+
0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
536+
),
537+
]
538+
text_regions = TextRegions.from_list(regions)
539+
540+
elements = list(text_regions.iter_elements())
541+
542+
# This should fail because iter_elements() doesn't pass source to TextRegion.from_coords()
543+
assert elements[0].source == Source.YOLOX, "iter_elements() should preserve source"
544+
545+
546+
def test_textregions_slice_preserves_sources():
547+
"""Test that TextRegions slicing preserves sources array"""
548+
from unstructured_inference.inference.elements import TextRegion
549+
550+
regions = [
551+
TextRegion.from_coords(
552+
0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
553+
),
554+
TextRegion.from_coords(
555+
10,
556+
10,
557+
20,
558+
20,
559+
text="second",
560+
source=Source.DETECTRON2_ONNX,
561+
is_extracted=IsExtracted.TRUE,
562+
),
563+
]
564+
text_regions = TextRegions.from_list(regions)
565+
566+
sliced = text_regions[0:1]
567+
568+
# This should fail because slice() doesn't handle sources
569+
assert sliced.sources.size > 0, "Sliced TextRegions should have sources"
570+
assert sliced.sources[0] == Source.YOLOX
571+
assert sliced.is_extracted_array[0] is IsExtracted.TRUE
572+
573+
574+
def test_textregions_post_init_handles_sources():
575+
"""Test that TextRegions.__post_init__() handles sources array initialization"""
576+
# Create with source but no sources array
577+
text_regions = TextRegions(
578+
element_coords=np.array([[0, 0, 10, 10], [10, 10, 20, 20]]), source=Source.YOLOX
579+
)
580+
581+
# This should fail because __post_init__() doesn't handle sources
582+
assert text_regions.sources.size > 0, "sources should be initialized from source"
583+
assert text_regions.sources[0] == Source.YOLOX
584+
assert text_regions.sources[1] == Source.YOLOX
585+
586+
587+
def test_textregions_from_coords_accepts_source():
588+
"""Test that TextRegion.from_coords() accepts source parameter"""
589+
# This should fail because from_coords() doesn't accept source parameter
590+
region = TextRegion.from_coords(
591+
0, 0, 10, 10, text="test", source=Source.YOLOX, is_extracted=IsExtracted.TRUE
592+
)
593+
594+
assert region.source == Source.YOLOX
595+
assert region.is_extracted
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.10" # pragma: no cover
1+
__version__ = "1.1.0" # pragma: no cover

unstructured_inference/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ class Source(Enum):
77
DETECTRON2_LP = "detectron2_lp"
88

99

10+
class IsExtracted(Enum):
11+
TRUE = "true"
12+
FALSE = "false"
13+
PARTIAL = "partial"
14+
15+
1016
class ElementType:
1117
PARAGRAPH = "Paragraph"
1218
IMAGE = "Image"

0 commit comments

Comments
 (0)