Skip to content

Commit 0eafc0f

Browse files
committed
Update schema to use is_extracted
1 parent 318065b commit 0eafc0f

File tree

3 files changed

+41
-56
lines changed

3 files changed

+41
-56
lines changed

unstructured_inference/constants.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,6 @@ class Source(Enum):
77
DETECTRON2_LP = "detectron2_lp"
88

99

10-
class TextSource(Enum):
11-
OCR = "ocr"
12-
EXTRACTED = "extracted"
13-
VLM = "vlm"
14-
15-
1610
class ElementType:
1711
PARAGRAPH = "Paragraph"
1812
IMAGE = "Image"

unstructured_inference/inference/elements.py

Lines changed: 20 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99

10-
from unstructured_inference.constants import Source, TextSource
10+
from unstructured_inference.constants import Source
1111
from unstructured_inference.math import safe_division
1212

1313

@@ -185,7 +185,7 @@ class TextRegion:
185185
bbox: Rectangle
186186
text: Optional[str] = None
187187
source: Optional[Source] = None
188-
text_source: Optional[TextSource] = None
188+
is_extracted: Optional[bool] = None
189189

190190
def __str__(self) -> str:
191191
return str(self.text)
@@ -199,13 +199,13 @@ def from_coords(
199199
y2: Union[int, float],
200200
text: Optional[str] = None,
201201
source: Optional[Source] = None,
202-
text_source: Optional[TextSource] = None,
202+
is_extracted: Optional[bool] = None,
203203
**kwargs,
204204
) -> TextRegion:
205205
"""Constructs a region from coordinates."""
206206
bbox = Rectangle(x1, y1, x2, y2)
207207

208-
return cls(text=text, source=source, text_source=text_source, bbox=bbox, **kwargs)
208+
return cls(text=text, source=source, is_extracted=is_extracted, bbox=bbox, **kwargs)
209209

210210

211211
@dataclass
@@ -214,27 +214,18 @@ class TextRegions:
214214
texts: np.ndarray = field(default_factory=lambda: np.array([]))
215215
sources: np.ndarray = field(default_factory=lambda: np.array([]))
216216
source: Source | None = None
217-
text_sources: np.ndarray = field(default_factory=lambda: np.array([]))
218-
text_source: TextSource | None = None
219-
_optional_array_attributes: list[str] = field(
220-
init=False, default_factory=lambda: ["texts", "sources", "text_sources"]
221-
)
222-
_scalar_to_array_mappings: dict[str, str] = field(
223-
init=False,
224-
default_factory=lambda: {
225-
"source": "sources",
226-
"text_source": "text_sources",
227-
},
228-
)
217+
is_extracted_array: np.ndarray = field(default_factory=lambda: np.array([]))
218+
is_extracted: bool | None = None
219+
_optional_array_attributes: list[str] = field(init=False, default_factory=lambda: ["texts", "sources", "is_extracted_array"])
220+
_scalar_to_array_mappings: dict[str, str] = field(init=False, default_factory=lambda: {
221+
"source": "sources",
222+
"is_extracted": "is_extracted_array",
223+
})
229224

230225
def __post_init__(self):
231226
element_size = self.element_coords.shape[0]
232227
for scalar, array in self._scalar_to_array_mappings.items():
233-
if (
234-
getattr(self, scalar) is not None
235-
and getattr(self, array).size == 0
236-
and element_size
237-
):
228+
if getattr(self, scalar) is not None and getattr(self, array).size == 0 and element_size:
238229
setattr(self, array, np.array([getattr(self, scalar)] * element_size))
239230
elif getattr(self, scalar) is None and getattr(self, array).size > 0:
240231
setattr(self, scalar, getattr(self, array)[0])
@@ -254,19 +245,19 @@ def slice(self, indices) -> TextRegions:
254245
element_coords=self.element_coords[indices],
255246
texts=self.texts[indices],
256247
sources=self.sources[indices],
257-
text_sources=self.text_sources[indices],
248+
is_extracted_array=self.is_extracted_array[indices],
258249
)
259250

260251
def iter_elements(self):
261252
"""iter text regions as one TextRegion per iteration; this returns a generator and has less
262253
memory impact than the as_list method"""
263-
for (x1, y1, x2, y2), text, source, text_source in zip(
254+
for (x1, y1, x2, y2), text, source, is_extracted in zip(
264255
self.element_coords,
265256
self.texts,
266257
self.sources,
267-
self.text_sources,
258+
self.is_extracted_array,
268259
):
269-
yield TextRegion.from_coords(x1, y1, x2, y2, text, source, text_source)
260+
yield TextRegion.from_coords(x1, y1, x2, y2, text, source, is_extracted)
270261

271262
def as_list(self):
272263
"""return a list of LayoutElement for backward compatibility"""
@@ -275,18 +266,18 @@ def as_list(self):
275266
@classmethod
276267
def from_list(cls, regions: list):
277268
"""create TextRegions from a list of TextRegion objects; the objects must have the same
278-
text_source"""
279-
coords, texts, sources, text_sources = [], [], [], []
269+
is_extracted"""
270+
coords, texts, sources, is_extracted_array = [], [], [], []
280271
for region in regions:
281272
coords.append((region.bbox.x1, region.bbox.y1, region.bbox.x2, region.bbox.y2))
282273
texts.append(region.text)
283274
sources.append(region.source)
284-
text_sources.append(region.text_source)
275+
is_extracted_array.append(region.is_extracted)
285276
return cls(
286277
element_coords=np.array(coords),
287278
texts=np.array(texts),
288279
sources=np.array(sources),
289-
text_sources=np.array(text_sources),
280+
is_extracted_array=np.array(is_extracted_array),
290281
)
291282

292283
def __len__(self):

unstructured_inference/inference/layoutelement.py

Lines changed: 21 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from scipy.sparse.csgraph import connected_components
99

1010
from unstructured_inference.config import inference_config
11-
from unstructured_inference.constants import Source, TextSource
11+
from unstructured_inference.constants import Source
1212
from unstructured_inference.inference.elements import (
1313
Rectangle,
1414
TextRegion,
@@ -31,7 +31,7 @@ class LayoutElements(TextRegions):
3131
default_factory=lambda: [
3232
"texts",
3333
"sources",
34-
"text_sources",
34+
"is_extracted_array",
3535
"element_probs",
3636
"element_class_ids",
3737
"text_as_html",
@@ -42,7 +42,7 @@ class LayoutElements(TextRegions):
4242
init=False,
4343
default_factory=lambda: {
4444
"source": "sources",
45-
"text_source": "text_sources",
45+
"is_extracted": "is_extracted_array",
4646
},
4747
)
4848

@@ -66,7 +66,7 @@ def __eq__(self, other: object) -> bool:
6666
== [other.element_class_id_map[idx] for idx in other.element_class_ids]
6767
)
6868
and np.array_equal(self.sources[mask], other.sources[mask])
69-
and np.array_equal(self.text_sources[mask], other.text_sources[mask])
69+
and np.array_equal(self.is_extracted_array[mask], other.is_extracted_array[mask])
7070
and np.array_equal(self.text_as_html[mask], other.text_as_html[mask])
7171
and np.array_equal(self.table_as_cells[mask], other.table_as_cells[mask])
7272
)
@@ -79,7 +79,7 @@ def slice(self, indices) -> LayoutElements:
7979
return LayoutElements(
8080
element_coords=self.element_coords[indices],
8181
texts=self.texts[indices],
82-
text_sources=self.text_sources[indices],
82+
is_extracted_array=self.is_extracted_array[indices],
8383
sources=self.sources[indices],
8484
element_probs=self.element_probs[indices],
8585
element_class_ids=self.element_class_ids[indices],
@@ -91,15 +91,15 @@ def slice(self, indices) -> LayoutElements:
9191
@classmethod
9292
def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
9393
"""concatenate a sequence of LayoutElements in order as one LayoutElements"""
94-
coords, texts, probs, class_ids, sources, text_sources = [], [], [], [], [], []
94+
coords, texts, probs, class_ids, sources, is_extracted_array = [], [], [], [], [], []
9595
text_as_html, table_as_cells = [], []
9696
class_id_reverse_map: dict[str, int] = {}
9797
for group in groups:
9898
coords.append(group.element_coords)
9999
texts.append(group.texts)
100100
probs.append(group.element_probs)
101101
sources.append(group.sources)
102-
text_sources.append(group.text_sources)
102+
is_extracted_array.append(group.is_extracted_array)
103103
text_as_html.append(group.text_as_html)
104104
table_as_cells.append(group.table_as_cells)
105105

@@ -121,7 +121,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
121121
element_class_ids=np.concatenate(class_ids),
122122
element_class_id_map={v: k for k, v in class_id_reverse_map.items()},
123123
sources=np.concatenate(sources),
124-
text_sources=np.concatenate(text_sources),
124+
is_extracted_array=np.concatenate(is_extracted_array),
125125
text_as_html=np.concatenate(text_as_html),
126126
table_as_cells=np.concatenate(table_as_cells),
127127
)
@@ -135,7 +135,7 @@ def iter_elements(self):
135135
prob,
136136
class_id,
137137
source,
138-
text_source,
138+
is_extracted,
139139
text_as_html,
140140
table_as_cells,
141141
) in zip(
@@ -144,7 +144,7 @@ def iter_elements(self):
144144
self.element_probs,
145145
self.element_class_ids,
146146
self.sources,
147-
self.text_sources,
147+
self.is_extracted_array,
148148
self.text_as_html,
149149
self.table_as_cells,
150150
):
@@ -161,7 +161,7 @@ def iter_elements(self):
161161
),
162162
prob=None if np.isnan(prob) else prob,
163163
source=source,
164-
text_source=text_source,
164+
is_extracted=is_extracted,
165165
text_as_html=text_as_html,
166166
table_as_cells=table_as_cells,
167167
)
@@ -174,7 +174,7 @@ def from_list(cls, elements: list):
174174
coords = np.empty((len_ele, 4), dtype=float)
175175
# text and probs can be Nones so use lists first then convert into array to avoid them being
176176
# filled as nan
177-
texts, text_as_html, table_as_cells, sources, text_sources, class_probs = (
177+
texts, text_as_html, table_as_cells, sources, is_extracted_array, class_probs = (
178178
[],
179179
[],
180180
[],
@@ -188,7 +188,7 @@ def from_list(cls, elements: list):
188188
coords[i] = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
189189
texts.append(element.text)
190190
sources.append(element.source)
191-
text_sources.append(element.text_source)
191+
is_extracted_array.append(element.is_extracted)
192192
text_as_html.append(element.text_as_html)
193193
table_as_cells.append(element.table_as_cells)
194194
class_probs.append(element.prob)
@@ -204,7 +204,7 @@ def from_list(cls, elements: list):
204204
element_class_ids=class_ids,
205205
element_class_id_map=dict(zip(range(len(unique_ids)), unique_ids)),
206206
sources=np.array(sources),
207-
text_sources=np.array(text_sources),
207+
is_extracted_array=np.array(is_extracted_array),
208208
text_as_html=np.array(text_as_html),
209209
table_as_cells=np.array(table_as_cells),
210210
)
@@ -227,7 +227,7 @@ def to_dict(self) -> dict:
227227
"type": self.type,
228228
"prob": self.prob,
229229
"source": self.source,
230-
"text_source": self.text_source,
230+
"is_extracted": self.is_extracted,
231231
}
232232
return out_dict
233233

@@ -238,12 +238,12 @@ def from_region(cls, region: TextRegion):
238238
type = region.type if hasattr(region, "type") else None
239239
prob = region.prob if hasattr(region, "prob") else None
240240
source = region.source if hasattr(region, "source") else None
241-
text_source = region.text_source if hasattr(region, "text_source") else None
241+
is_extracted = region.is_extracted if hasattr(region, "is_extracted") else None
242242
return cls(
243243
bbox=region.bbox,
244244
text=text,
245245
source=source,
246-
text_source=text_source,
246+
is_extracted=is_extracted,
247247
type=type,
248248
prob=prob,
249249
)
@@ -257,7 +257,7 @@ def from_coords(
257257
y2: Union[int, float],
258258
text: Optional[str] = None,
259259
source: Optional[Source] = None,
260-
text_source: Optional[TextSource] = None,
260+
is_extracted: bool = None,
261261
type: Optional[str] = None,
262262
prob: Optional[float] = None,
263263
text_as_html: Optional[str] = None,
@@ -268,7 +268,7 @@ def from_coords(
268268
bbox = Rectangle(x1, y1, x2, y2)
269269
return cls(
270270
text=text,
271-
text_source=text_source,
271+
is_extracted=is_extracted,
272272
type=type,
273273
prob=prob,
274274
source=source,
@@ -427,7 +427,7 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float =
427427
final_attrs: dict[str, Any] = {
428428
"element_class_id_map": elements.element_class_id_map,
429429
}
430-
for attr in ("element_class_ids", "element_probs", "texts", "sources", "text_sources"):
430+
for attr in ("element_class_ids", "element_probs", "texts", "sources", "is_extracted_array"):
431431
if (original_attr := getattr(elements, attr)) is None:
432432
continue
433433
final_attrs[attr] = original_attr[sorted_by_area][mask][sorted_by_y1]
@@ -503,7 +503,7 @@ def clean_layoutelements_for_class(
503503

504504
final_coords = np.vstack([target_coords[mask], other_coords[other_mask]])
505505
final_attrs: dict[str, Any] = {"element_class_id_map": elements.element_class_id_map}
506-
for attr in ("element_class_ids", "element_probs", "texts", "sources", "text_sources"):
506+
for attr in ("element_class_ids", "element_probs", "texts", "sources", "is_extracted_array"):
507507
if (original_attr := getattr(elements, attr)) is None:
508508
continue
509509
final_attrs[attr] = np.concatenate(

0 commit comments

Comments
 (0)