Skip to content

Commit 7f82e52

Browse files
authored
fix missing source (#396)
This PR fixes a bug that was found when working with `unstructured`: https://github.com/Unstructured-IO/unstructured/actions/runs/11403980075/job/31732726752#step:6:1778 ## test - this PR should fix the failing test above in `unstructured` ci - this PR expands the test on `clean_layoutelements` to test `source` is kept ## note for release Since we have not tagged a release for 0.7.42 this PR opt to not increase version number and include itself as part of 0.7.42
1 parent 42eebd3 commit 7f82e52

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
## 0.7.42
22

3+
* fix: fix missing source after cleaning layout elements
34
* Remove chipper model
45

56
## 0.7.41

test_unstructured_inference/test_elements.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def test_layoutelements():
6161
element_coords=coords,
6262
element_class_ids=element_class_ids,
6363
element_class_id_map=class_map,
64+
source="yolox",
6465
)
6566

6667

@@ -345,6 +346,7 @@ def test_clean_layoutelements(test_layoutelements):
345346
elements[1].bbox.x2,
346347
elements[1].bbox.x2,
347348
) == (2, 2, 3, 3)
349+
assert elements[0].source == elements[1].source == "yolox"
348350

349351

350352
@pytest.mark.parametrize(

unstructured_inference/inference/layoutelement.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,15 @@ def slice(self, indices) -> LayoutElements:
7373
@classmethod
7474
def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
7575
"""concatenate a sequence of LayoutElements in order as one LayoutElements"""
76-
coords, texts, probs, class_ids = [], [], [], []
76+
coords, texts, probs, class_ids, sources = [], [], [], [], []
7777
class_id_map = {}
7878
for group in groups:
7979
coords.append(group.element_coords)
8080
texts.append(group.texts)
8181
probs.append(group.element_probs)
8282
class_ids.append(group.element_class_ids)
83+
if group.source:
84+
sources.append(group.source)
8385
if group.element_class_id_map:
8486
class_id_map.update(group.element_class_id_map)
8587
return cls(
@@ -88,7 +90,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
8890
element_probs=np.concatenate(probs),
8991
element_class_ids=np.concatenate(class_ids),
9092
element_class_id_map=class_id_map,
91-
source=group.source,
93+
source=sources[0] if sources else None,
9294
)
9395

9496
def as_list(self):
@@ -439,7 +441,10 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float =
439441
final_coords = sorted_coords[mask]
440442
sorted_by_y1 = np.argsort(final_coords[:, 1])
441443

442-
final_attrs: dict[str, Any] = {"element_class_id_map": elements.element_class_id_map}
444+
final_attrs: dict[str, Any] = {
445+
"element_class_id_map": elements.element_class_id_map,
446+
"source": elements.source,
447+
}
443448
for attr in ("element_class_ids", "element_probs", "texts"):
444449
if (original_attr := getattr(elements, attr)) is None:
445450
continue

0 commit comments

Comments
 (0)