Skip to content

Commit 85bcdc1

Browse files
authored
feat: add back source attribute for backward compatibility (#407)
This PR adds `source` back to `TextRegions` and `LayoutElements` for backward compatibility.
1 parent 655ea34 commit 85bcdc1

File tree

5 files changed

+26
-5
lines changed

5 files changed

+26
-5
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
## 0.8.6
2+
3+
* feat: add back `source` to `TextRegions` and `LayoutElements` for backward compatibility
4+
15
## 0.8.5
26

3-
* fix: remove `pdfplumber` but include `pdfminer-six==20240706` to update `pdfminer`
7+
* fix: remove `pdfplumber` but include `pdfminer-six==20240706` to update `pdfminer`
48

59
## 0.8.4
610

test_unstructured_inference/test_elements.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_layoutelements():
6161
element_coords=coords,
6262
element_class_ids=element_class_ids,
6363
element_class_id_map=class_map,
64-
sources=np.array(["yolox"] * len(element_class_ids)),
64+
source="yolox",
6565
)
6666

6767

@@ -441,20 +441,22 @@ def test_layoutelements_to_list_and_back(test_layoutelements):
441441
def test_layoutelements_from_list_no_elements():
442442
back = LayoutElements.from_list(elements=[])
443443
assert back.sources.size == 0
444+
assert back.source is None
444445
assert back.element_coords.size == 0
445446

446447

447448
def test_textregions_from_list_no_elements():
448449
back = TextRegions.from_list(regions=[])
449450
assert back.sources.size == 0
451+
assert back.source is None
450452
assert back.element_coords.size == 0
451453

452454

453455
def test_layoutelements_concatenate():
454456
layout1 = LayoutElements(
455457
element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]),
456458
texts=np.array(["a", "two"]),
457-
sources=np.array(["yolox", "yolox"]),
459+
source="yolox",
458460
element_class_ids=np.array([0, 1]),
459461
element_class_id_map={0: "type0", 1: "type1"},
460462
)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.8.5" # pragma: no cover
1+
__version__ = "0.8.6" # pragma: no cover

unstructured_inference/inference/elements.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,21 @@ class TextRegions:
211211
element_coords: np.ndarray
212212
texts: np.ndarray = field(default_factory=lambda: np.array([]))
213213
sources: np.ndarray = field(default_factory=lambda: np.array([]))
214+
source: Source | None = None
214215

215216
def __post_init__(self):
216217
if self.texts.size == 0 and self.element_coords.size > 0:
217218
self.texts = np.array([None] * self.element_coords.shape[0])
218219

220+
# for backward compatibility; also allows a single `source` value to set `sources` for all regions
221+
if self.sources.size == 0 and self.element_coords.size > 0:
222+
self.sources = np.array([self.source] * self.element_coords.shape[0])
223+
elif self.source is None and self.sources.size:
224+
self.source = self.sources[0]
225+
226+
# convert to float so the data type is consistent (e.g., None becomes np.nan)
227+
self.element_coords = self.element_coords.astype(float)
228+
219229
def slice(self, indices) -> TextRegions:
220230
"""slice text regions based on indices"""
221231
return TextRegions(

unstructured_inference/inference/layoutelement.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,18 @@ def __post_init__(self):
4141
"element_probs",
4242
"element_class_ids",
4343
"texts",
44-
"sources",
4544
"text_as_html",
4645
"table_as_cells",
4746
):
4847
if getattr(self, attr).size == 0 and element_size:
4948
setattr(self, attr, np.array([None] * element_size))
5049

50+
# for backward compatibility; also allows a single `source` value to set `sources` for all regions
51+
if self.sources.size == 0 and self.element_coords.size > 0:
52+
self.sources = np.array([self.source] * self.element_coords.shape[0])
53+
elif self.source is None and self.sources.size:
54+
self.source = self.sources[0]
55+
5156
self.element_probs = self.element_probs.astype(float)
5257

5358
def __eq__(self, other: object) -> bool:

0 commit comments

Comments
 (0)