Skip to content

Commit 4d0c20a

Browse files
authored
Feat/add more attributes to layoutelements (#404)
* feat: add `text_as_html` and `table_as_cells` to `LayoutElements` class as new attributes * feat: replace the single valueed `source` attribute from `TextRegions` and `LayoutElements` with an array attribute `sources`
1 parent ab25fb9 commit 4d0c20a

File tree

7 files changed

+86
-48
lines changed

7 files changed

+86
-48
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 0.8.4
2+
3+
* feat: add `text_as_html` and `table_as_cells` to `LayoutElements` class as new attributes
4+
* feat: replace the single valueed `source` attribute from `TextRegions` and `LayoutElements` with an array attribute `sources`
5+
16
## 0.8.3
27

38
* fix: removed `layoutelement.from_lp_textblock()` and related tests as it's not used

test_unstructured_inference/models/test_yolox.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ def test_layout_yolox_local_parsing_image():
3232
def test_layout_yolox_local_parsing_pdf():
3333
filename = os.path.join("sample-docs", "loremipsum.pdf")
3434
document_layout = process_file_with_model(filename, model_name="yolox")
35-
content = str(document_layout)
36-
assert "libero fringilla" in content
3735
assert len(document_layout.pages) == 1
3836
# NOTE(benjamin) The example sent to the test contains 5 text detections
3937
text_elements = [e for e in document_layout.pages[0].elements if e.type == "Text"]

test_unstructured_inference/test_elements.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_layoutelements():
6161
element_coords=coords,
6262
element_class_ids=element_class_ids,
6363
element_class_id_map=class_map,
64-
source="yolox",
64+
sources=np.array(["yolox"] * len(element_class_ids)),
6565
)
6666

6767

@@ -440,32 +440,33 @@ def test_layoutelements_to_list_and_back(test_layoutelements):
440440

441441
def test_layoutelements_from_list_no_elements():
442442
back = LayoutElements.from_list(elements=[])
443-
assert back.source is None
443+
assert back.sources.size == 0
444444
assert back.element_coords.size == 0
445445

446446

447447
def test_textregions_from_list_no_elements():
448448
back = TextRegions.from_list(regions=[])
449-
assert back.source is None
449+
assert back.sources.size == 0
450450
assert back.element_coords.size == 0
451451

452452

453453
def test_layoutelements_concatenate():
454454
layout1 = LayoutElements(
455455
element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]),
456456
texts=np.array(["a", "two"]),
457-
source=None,
457+
sources=np.array(["yolox", "yolox"]),
458458
element_class_ids=np.array([0, 1]),
459459
element_class_id_map={0: "type0", 1: "type1"},
460460
)
461461
layout2 = LayoutElements(
462462
element_coords=np.array([[10, 10, 2, 2], [20, 20, 1, 1]]),
463463
texts=np.array(["three", "4"]),
464-
source=None,
464+
sources=np.array(["ocr", "ocr"]),
465465
element_class_ids=np.array([0, 1]),
466466
element_class_id_map={0: "type1", 1: "type2"},
467467
)
468468
joint = LayoutElements.concatenate([layout1, layout2])
469469
assert joint.texts.tolist() == ["a", "two", "three", "4"]
470+
assert joint.sources.tolist() == ["yolox", "yolox", "ocr", "ocr"]
470471
assert joint.element_class_ids.tolist() == [0, 1, 1, 2]
471472
assert joint.element_class_id_map == {0: "type0", 1: "type1", 2: "type2"}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.8.3" # pragma: no cover
1+
__version__ = "0.8.4" # pragma: no cover

unstructured_inference/inference/elements.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def from_coords(
210210
class TextRegions:
211211
element_coords: np.ndarray
212212
texts: np.ndarray = field(default_factory=lambda: np.array([]))
213-
source: Source | None = None
213+
sources: np.ndarray = field(default_factory=lambda: np.array([]))
214214

215215
def __post_init__(self):
216216
if self.texts.size == 0 and self.element_coords.size > 0:
@@ -221,31 +221,37 @@ def slice(self, indices) -> TextRegions:
221221
return TextRegions(
222222
element_coords=self.element_coords[indices],
223223
texts=self.texts[indices],
224-
source=self.source,
224+
sources=self.sources[indices],
225225
)
226226

227+
def iter_elements(self):
228+
"""iter text regions as one TextRegion per iteration; this returns a generator and has less
229+
memory impact than the as_list method"""
230+
for (x1, y1, x2, y2), text, source in zip(
231+
self.element_coords,
232+
self.texts,
233+
self.sources,
234+
):
235+
yield TextRegion.from_coords(x1, y1, x2, y2, text, source)
236+
227237
def as_list(self):
228-
"""return a list of TextRegion objects representing the data"""
229-
if self.texts is None:
230-
return [
231-
TextRegion.from_coords(x1, y1, x2, y2, None, self.source)
232-
for (x1, y1, x2, y2) in self.element_coords
233-
]
234-
return [
235-
TextRegion.from_coords(x1, y1, x2, y2, text, self.source)
236-
for (x1, y1, x2, y2), text in zip(self.element_coords, self.texts)
237-
]
238+
"""return a list of LayoutElement for backward compatibility"""
239+
return list(self.iter_elements())
238240

239241
@classmethod
240242
def from_list(cls, regions: list):
241243
"""create TextRegions from a list of TextRegion objects; the objects must have the same
242244
source"""
243-
coords, texts = [], []
245+
coords, texts, sources = [], [], []
244246
for region in regions:
245247
coords.append((region.bbox.x1, region.bbox.y1, region.bbox.x2, region.bbox.y2))
246248
texts.append(region.text)
247-
source = regions[0].source if regions else None
248-
return cls(element_coords=np.array(coords), texts=np.array(texts), source=source)
249+
sources.append(region.source)
250+
return cls(
251+
element_coords=np.array(coords),
252+
texts=np.array(texts),
253+
sources=np.array(sources),
254+
)
249255

250256
def __len__(self):
251257
return self.element_coords.shape[0]

unstructured_inference/inference/layoutelement.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,21 @@ class LayoutElements(TextRegions):
3030
element_probs: np.ndarray = field(default_factory=lambda: np.array([]))
3131
element_class_ids: np.ndarray = field(default_factory=lambda: np.array([]))
3232
element_class_id_map: dict[int, str] = field(default_factory=dict)
33+
text_as_html: np.ndarray = field(default_factory=lambda: np.array([]))
34+
table_as_cells: np.ndarray = field(default_factory=lambda: np.array([]))
3335

3436
def __post_init__(self):
3537
element_size = self.element_coords.shape[0]
36-
for attr in ("element_probs", "element_class_ids", "texts"):
38+
# NOTE: maybe we should create an attribute _optional_attributes: list[str] to store this
39+
# list
40+
for attr in (
41+
"element_probs",
42+
"element_class_ids",
43+
"texts",
44+
"sources",
45+
"text_as_html",
46+
"table_as_cells",
47+
):
3748
if getattr(self, attr).size == 0 and element_size:
3849
setattr(self, attr, np.array([None] * element_size))
3950

@@ -54,31 +65,37 @@ def __eq__(self, other: object) -> bool:
5465
[self.element_class_id_map[idx] for idx in self.element_class_ids]
5566
== [other.element_class_id_map[idx] for idx in other.element_class_ids]
5667
)
57-
and self.source == other.source
68+
and np.array_equal(self.sources[mask], other.sources[mask])
69+
and np.array_equal(self.text_as_html[mask], other.text_as_html[mask])
70+
and np.array_equal(self.table_as_cells[mask], other.table_as_cells[mask])
5871
)
5972

6073
def slice(self, indices) -> LayoutElements:
6174
"""slice and return only selected indices"""
6275
return LayoutElements(
6376
element_coords=self.element_coords[indices],
6477
texts=self.texts[indices],
65-
source=self.source,
78+
sources=self.sources[indices],
6679
element_probs=self.element_probs[indices],
6780
element_class_ids=self.element_class_ids[indices],
6881
element_class_id_map=self.element_class_id_map,
82+
text_as_html=self.text_as_html[indices],
83+
table_as_cells=self.table_as_cells[indices],
6984
)
7085

7186
@classmethod
7287
def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
7388
"""concatenate a sequence of LayoutElements in order as one LayoutElements"""
7489
coords, texts, probs, class_ids, sources = [], [], [], [], []
90+
text_as_html, table_as_cells = [], []
7591
class_id_reverse_map: dict[str, int] = {}
7692
for group in groups:
7793
coords.append(group.element_coords)
7894
texts.append(group.texts)
7995
probs.append(group.element_probs)
80-
if group.source:
81-
sources.append(group.source)
96+
sources.append(group.sources)
97+
text_as_html.append(group.text_as_html)
98+
table_as_cells.append(group.table_as_cells)
8299

83100
idx = group.element_class_ids.copy()
84101
if group.element_class_id_map:
@@ -97,13 +114,24 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
97114
element_probs=np.concatenate(probs),
98115
element_class_ids=np.concatenate(class_ids),
99116
element_class_id_map={v: k for k, v in class_id_reverse_map.items()},
100-
source=sources[0] if sources else None,
117+
sources=np.concatenate(sources),
118+
text_as_html=np.concatenate(text_as_html),
119+
table_as_cells=np.concatenate(table_as_cells),
101120
)
102121

103-
def as_list(self):
104-
"""return a list of LayoutElement for backward compatibility"""
105-
return [
106-
LayoutElement.from_coords(
122+
def iter_elements(self):
123+
"""iter elements as one LayoutElement per iteration; this returns a generator and has less
124+
memory impact than the as_list method"""
125+
for (x1, y1, x2, y2), text, prob, class_id, source, text_as_html, table_as_cells in zip(
126+
self.element_coords,
127+
self.texts,
128+
self.element_probs,
129+
self.element_class_ids,
130+
self.sources,
131+
self.text_as_html,
132+
self.table_as_cells,
133+
):
134+
yield LayoutElement.from_coords(
107135
x1,
108136
y1,
109137
x2,
@@ -115,15 +143,10 @@ def as_list(self):
115143
else None
116144
),
117145
prob=None if np.isnan(prob) else prob,
118-
source=self.source,
146+
source=source,
147+
text_as_html=text_as_html,
148+
table_as_cells=table_as_cells,
119149
)
120-
for (x1, y1, x2, y2), text, prob, class_id in zip(
121-
self.element_coords,
122-
self.texts,
123-
self.element_probs,
124-
self.element_class_ids,
125-
)
126-
]
127150

128151
@classmethod
129152
def from_list(cls, elements: list):
@@ -133,13 +156,15 @@ def from_list(cls, elements: list):
133156
coords = np.empty((len_ele, 4), dtype=float)
134157
# text and probs can be Nones so use lists first then convert into array to avoid them being
135158
# filled as nan
136-
texts = []
137-
class_probs = []
159+
texts, text_as_html, table_as_cells, sources, class_probs = [], [], [], [], []
138160
class_types = np.empty((len_ele,), dtype="object")
139161

140162
for i, element in enumerate(elements):
141163
coords[i] = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
142164
texts.append(element.text)
165+
sources.append(element.source)
166+
text_as_html.append(element.text_as_html)
167+
table_as_cells.append(element.table_as_cells)
143168
class_probs.append(element.prob)
144169
class_types[i] = element.type or "None"
145170

@@ -152,7 +177,9 @@ def from_list(cls, elements: list):
152177
element_probs=np.array(class_probs),
153178
element_class_ids=class_ids,
154179
element_class_id_map=dict(zip(range(len(unique_ids)), unique_ids)),
155-
source=elements[0].source if len_ele else None,
180+
sources=np.array(sources),
181+
text_as_html=np.array(text_as_html),
182+
table_as_cells=np.array(table_as_cells),
156183
)
157184

158185

@@ -162,6 +189,8 @@ class LayoutElement(TextRegion):
162189
prob: Optional[float] = None
163190
image_path: Optional[str] = None
164191
parent: Optional[LayoutElement] = None
192+
text_as_html: Optional[str] = None
193+
table_as_cells: Optional[str] = None
165194

166195
def to_dict(self) -> dict:
167196
"""Converts the class instance to dictionary form."""
@@ -432,9 +461,8 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float =
432461

433462
final_attrs: dict[str, Any] = {
434463
"element_class_id_map": elements.element_class_id_map,
435-
"source": elements.source,
436464
}
437-
for attr in ("element_class_ids", "element_probs", "texts"):
465+
for attr in ("element_class_ids", "element_probs", "texts", "sources"):
438466
if (original_attr := getattr(elements, attr)) is None:
439467
continue
440468
final_attrs[attr] = original_attr[sorted_by_area][mask][sorted_by_y1]
@@ -510,7 +538,7 @@ def clean_layoutelements_for_class(
510538

511539
final_coords = np.vstack([target_coords[mask], other_coords[other_mask]])
512540
final_attrs: dict[str, Any] = {"element_class_id_map": elements.element_class_id_map}
513-
for attr in ("element_class_ids", "element_probs", "texts"):
541+
for attr in ("element_class_ids", "element_probs", "texts", "sources"):
514542
if (original_attr := getattr(elements, attr)) is None:
515543
continue
516544
final_attrs[attr] = np.concatenate(

unstructured_inference/models/yolox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def image_processing(
140140
element_probs=sorted_dets[:, 4].astype(float),
141141
element_class_ids=sorted_dets[:, 5].astype(int),
142142
element_class_id_map=self.layout_classes,
143-
source=Source.YOLOX,
143+
sources=np.array([Source.YOLOX] * sorted_dets.shape[0]),
144144
)
145145

146146

0 commit comments

Comments
 (0)