Skip to content

Commit bd61292

Browse files
authored
feat: large model (#132)
Added large model to accessible models. This required adding a new code path for this type of model. I compensated for the fact that elements from this model don't have location data by adding a new class LocationlessLayoutElement, but long term I think we should alter the TextRegion to have a bbox property instead of subclassing Rectangle.
1 parent ce1242c commit bd61292

File tree

10 files changed

+541
-40
lines changed

10 files changed

+541
-40
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.5.3
2+
3+
* Refactor for large model
4+
15
## 0.5.2
26

37
* Combine inferred elements with extracted elements

test_unstructured_inference/inference/test_layout.py

Lines changed: 154 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
import unstructured_inference.models.base as models
1111
from unstructured_inference.inference import elements, layout, layoutelement
1212
from unstructured_inference.models import detectron2, tesseract
13+
from unstructured_inference.models.unstructuredmodel import (
14+
UnstructuredElementExtractionModel,
15+
UnstructuredObjectDetectionModel,
16+
)
1317

1418

1519
@pytest.fixture()
@@ -84,15 +88,15 @@ def test_get_page_elements(monkeypatch, mock_final_layout):
8488
number=0,
8589
image=image,
8690
layout=mock_final_layout,
87-
model=MockLayoutModel(mock_final_layout),
91+
detection_model=MockLayoutModel(mock_final_layout),
8892
)
8993

90-
elements = page.get_elements_with_model(inplace=False)
94+
elements = page.get_elements_with_detection_model(inplace=False)
9195

9296
assert str(elements[0]) == "A Catchy Title"
9397
assert str(elements[1]).startswith("A very repetitive narrative.")
9498

95-
page.get_elements_with_model(inplace=True)
99+
page.get_elements_with_detection_model(inplace=True)
96100
assert elements == page.elements
97101

98102

@@ -130,9 +134,9 @@ def test_get_page_elements_with_ocr(monkeypatch):
130134
number=0,
131135
image=image,
132136
layout=doc_initial_layout,
133-
model=MockLayoutModel(doc_final_layout),
137+
detection_model=MockLayoutModel(doc_final_layout),
134138
)
135-
page.get_elements_with_model()
139+
page.get_elements_with_detection_model()
136140

137141
assert str(page) == "\n\nAn Even Catchier Title"
138142

@@ -152,7 +156,7 @@ def test_read_pdf(monkeypatch, mock_initial_layout, mock_final_layout):
152156

153157
with patch.object(layout, "load_pdf", return_value=(layouts, images)):
154158
model = layout.get_model("detectron2_lp")
155-
doc = layout.DocumentLayout.from_file("fake-file.pdf", model=model)
159+
doc = layout.DocumentLayout.from_file("fake-file.pdf", detection_model=model)
156160

157161
assert str(doc).startswith("A Catchy Title")
158162
assert str(doc).count("A Catchy Title") == 2 # Once for each page
@@ -172,7 +176,17 @@ def test_process_data_with_model(monkeypatch, mock_final_layout, model_name):
172176
"from_file",
173177
lambda *args, **kwargs: layout.DocumentLayout.from_pages([]),
174178
)
175-
with patch("builtins.open", mock_open(read_data=b"000000")), open("") as fp:
179+
180+
def new_isinstance(obj, cls):
181+
if type(obj) == MockLayoutModel:
182+
return True
183+
else:
184+
return isinstance(obj, cls)
185+
186+
with patch("builtins.open", mock_open(read_data=b"000000")), patch(
187+
"unstructured_inference.inference.layout.UnstructuredObjectDetectionModel",
188+
MockLayoutModel,
189+
), open("") as fp:
176190
assert layout.process_data_with_model(fp, model_name=model_name)
177191

178192

@@ -305,7 +319,7 @@ def test_from_image_file(monkeypatch, mock_final_layout, filetype):
305319
def mock_get_elements(self, *args, **kwargs):
306320
self.elements = [mock_final_layout]
307321

308-
monkeypatch.setattr(layout.PageLayout, "get_elements_with_model", mock_get_elements)
322+
monkeypatch.setattr(layout.PageLayout, "get_elements_with_detection_model", mock_get_elements)
309323
elements = (
310324
layout.DocumentLayout.from_image_file(f"sample-docs/loremipsum.{filetype}")
311325
.pages[0]
@@ -342,7 +356,7 @@ def test_get_elements_from_layout(mock_initial_layout, idx):
342356

343357
def test_page_numbers_in_page_objects():
344358
with patch(
345-
"unstructured_inference.inference.layout.PageLayout.get_elements_with_model",
359+
"unstructured_inference.inference.layout.PageLayout.get_elements_with_detection_model",
346360
) as mock_get_elements:
347361
doc = layout.DocumentLayout.from_file("sample-docs/layout-parser-paper.pdf")
348362
mock_get_elements.assert_called()
@@ -352,12 +366,16 @@ def test_page_numbers_in_page_objects():
352366
@pytest.mark.parametrize(
353367
("fixed_layouts", "called_method", "not_called_method"),
354368
[
355-
([MockLayout()], "get_elements_from_layout", "get_elements_with_model"),
356-
(None, "get_elements_with_model", "get_elements_from_layout"),
369+
([MockLayout()], "get_elements_from_layout", "get_elements_with_detection_model"),
370+
(None, "get_elements_with_detection_model", "get_elements_from_layout"),
357371
],
358372
)
359373
def test_from_file_fixed_layout(fixed_layouts, called_method, not_called_method):
360-
with patch.object(layout.PageLayout, "get_elements_with_model", return_value=[]), patch.object(
374+
with patch.object(
375+
layout.PageLayout,
376+
"get_elements_with_detection_model",
377+
return_value=[],
378+
), patch.object(
361379
layout.PageLayout,
362380
"get_elements_from_layout",
363381
return_value=[],
@@ -524,7 +542,8 @@ def test_load_pdf_with_multicolumn_layout_and_ocr(filename="sample-docs/design-t
524542
assert element.text.startswith(test_snippets[i])
525543

526544

527-
def test_annotate():
545+
@pytest.mark.parametrize("colors", ["red", None])
546+
def test_annotate(colors):
528547
test_image_arr = np.ones((100, 100, 3), dtype="uint8")
529548
image = Image.fromarray(test_image_arr)
530549
page = layout.PageLayout(number=1, image=image, layout=None)
@@ -533,7 +552,7 @@ def test_annotate():
533552
coords2 = (1, 10, 7, 11)
534553
rect2 = elements.Rectangle(*coords2)
535554
page.elements = [rect1, rect2]
536-
annotated_image = page.annotate(colors="red")
555+
annotated_image = page.annotate(colors=colors)
537556
annotated_array = np.array(annotated_image)
538557
for x1, y1, x2, y2 in [coords1, coords2]:
539558
# Make sure the pixels on the edge of the box are red
@@ -595,8 +614,129 @@ def test_layout_order(ordering_layout):
595614
layout,
596615
"load_pdf",
597616
lambda *args, **kwargs: ([[]], [mock_image]),
617+
), patch.object(
618+
layout,
619+
"UnstructuredObjectDetectionModel",
620+
object,
598621
):
599622
doc = layout.DocumentLayout.from_file("sample-docs/layout-parser-paper.pdf")
600623
page = doc.pages[0]
601624
for n, element in enumerate(page.elements):
602625
assert element.text == str(n)
626+
627+
628+
def test_page_layout_raises_when_multiple_models_passed(mock_image, mock_initial_layout):
629+
with pytest.raises(ValueError):
630+
layout.PageLayout(
631+
0,
632+
mock_image,
633+
mock_initial_layout,
634+
detection_model="something",
635+
element_extraction_model="something else",
636+
)
637+
638+
639+
class MockElementExtractionModel:
640+
def __call__(self, x):
641+
return [1, 2, 3]
642+
643+
644+
@pytest.mark.parametrize(("inplace", "expected"), [(True, None), (False, [1, 2, 3])])
645+
def test_get_elements_using_image_extraction(mock_image, inplace, expected):
646+
page = layout.PageLayout(
647+
1,
648+
mock_image,
649+
None,
650+
element_extraction_model=MockElementExtractionModel(),
651+
)
652+
assert page.get_elements_using_image_extraction(inplace=inplace) == expected
653+
654+
655+
def test_get_elements_using_image_extraction_raises_with_no_extraction_model(mock_image):
656+
page = layout.PageLayout(1, mock_image, None, element_extraction_model=None)
657+
with pytest.raises(ValueError):
658+
page.get_elements_using_image_extraction()
659+
660+
661+
def test_get_elements_with_detection_model_raises_with_wrong_default_model(monkeypatch):
662+
monkeypatch.setattr(layout, "get_model", lambda *x: MockLayoutModel(mock_final_layout))
663+
page = layout.PageLayout(1, mock_image, None)
664+
with pytest.raises(NotImplementedError):
665+
page.get_elements_with_detection_model()
666+
667+
668+
@pytest.mark.parametrize(
669+
(
670+
"detection_model",
671+
"element_extraction_model",
672+
"detection_model_called",
673+
"element_extraction_model_called",
674+
),
675+
[(None, "asdf", False, True), ("asdf", None, True, False)],
676+
)
677+
def test_from_image(
678+
mock_image,
679+
detection_model,
680+
element_extraction_model,
681+
detection_model_called,
682+
element_extraction_model_called,
683+
):
684+
with patch.object(
685+
layout.PageLayout,
686+
"get_elements_using_image_extraction",
687+
) as mock_image_extraction, patch.object(
688+
layout.PageLayout,
689+
"get_elements_with_detection_model",
690+
) as mock_detection:
691+
layout.PageLayout.from_image(
692+
mock_image,
693+
detection_model=detection_model,
694+
element_extraction_model=element_extraction_model,
695+
)
696+
assert mock_image_extraction.called == element_extraction_model_called
697+
assert mock_detection.called == detection_model_called
698+
699+
700+
class MockUnstructuredElementExtractionModel(UnstructuredElementExtractionModel):
701+
def initialize(self, *args, **kwargs):
702+
return super().initialize(*args, **kwargs)
703+
704+
def predict(self, x: Image):
705+
return super().predict(x)
706+
707+
708+
class MockUnstructuredDetectionModel(UnstructuredObjectDetectionModel):
709+
def initialize(self, *args, **kwargs):
710+
return super().initialize(*args, **kwargs)
711+
712+
def predict(self, x: Image):
713+
return super().predict(x)
714+
715+
716+
@pytest.mark.parametrize(
717+
("model_type", "is_detection_model"),
718+
[
719+
(MockUnstructuredElementExtractionModel, False),
720+
(MockUnstructuredDetectionModel, True),
721+
],
722+
)
723+
def test_process_file_with_model_routing(monkeypatch, model_type, is_detection_model):
724+
model = model_type()
725+
monkeypatch.setattr(layout, "get_model", lambda *x: model)
726+
with patch.object(layout.DocumentLayout, "from_file") as mock_from_file:
727+
layout.process_file_with_model("asdf", model_name="fake", is_image=False)
728+
if is_detection_model:
729+
detection_model = model
730+
element_extraction_model = None
731+
else:
732+
detection_model = None
733+
element_extraction_model = model
734+
mock_from_file.assert_called_once_with(
735+
"asdf",
736+
detection_model=detection_model,
737+
element_extraction_model=element_extraction_model,
738+
ocr_strategy="auto",
739+
ocr_languages="eng",
740+
fixed_layouts=None,
741+
extract_tables=False,
742+
)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from unittest import mock
2+
3+
import pytest
4+
from PIL import Image
5+
6+
from unstructured_inference.models import largemodel
7+
8+
9+
def test_initialize():
10+
with mock.patch.object(
11+
largemodel.AutoTokenizer,
12+
"from_pretrained",
13+
) as mock_tokenizer, mock.patch.object(
14+
largemodel,
15+
"DonutProcessor",
16+
) as mock_donut_processor, mock.patch.object(
17+
largemodel,
18+
"DonutImageProcessor",
19+
) as mock_donut_image_processor, mock.patch.object(
20+
largemodel.VisionEncoderDecoderModel,
21+
"from_pretrained",
22+
) as mock_vision_encoder_decoder_model:
23+
model = largemodel.UnstructuredLargeModel()
24+
model.initialize("", "", "")
25+
mock_tokenizer.assert_called_once()
26+
mock_donut_processor.assert_called_once()
27+
mock_donut_image_processor.assert_called_once()
28+
mock_vision_encoder_decoder_model.assert_called_once()
29+
30+
31+
class MockToList:
32+
def tolist(self):
33+
return [[5, 4, 3, 2, 1]]
34+
35+
36+
class MockModel:
37+
def generate(*args, **kwargs):
38+
return MockToList()
39+
40+
41+
def mock_initialize(self, *arg, **kwargs):
42+
self.model = MockModel()
43+
self.processor = mock.MagicMock()
44+
45+
46+
def test_predict_tokens():
47+
with mock.patch.object(largemodel.UnstructuredLargeModel, "initialize", mock_initialize):
48+
model = largemodel.UnstructuredLargeModel()
49+
model.initialize()
50+
with open("sample-docs/loremipsum.png", "rb") as fp:
51+
im = Image.open(fp)
52+
tokens = model.predict_tokens(im)
53+
assert tokens[1:-1] == [5, 4, 3, 2, 1]
54+
55+
56+
@pytest.mark.parametrize(
57+
("decoded_str", "expected_classes"),
58+
[
59+
(
60+
"<s_Title>Hi buddy!</s_Title><s_Text>There is some text here.</s_Text>",
61+
["Title", "Text"],
62+
),
63+
("<s_Title>Hi buddy!</s_Title><s_Text>There is some text here.", ["Title", "Text"]),
64+
],
65+
)
66+
def test_postprocess(decoded_str, expected_classes):
67+
with mock.patch.object(largemodel.UnstructuredLargeModel, "initialize", mock_initialize):
68+
pass
69+
model = largemodel.UnstructuredLargeModel()
70+
tokenizer_model = "xlm-roberta-large"
71+
pre_trained_model = "nielsr/donut-base"
72+
model.initialize(tokenizer_model, pre_trained_model, None)
73+
74+
tokens = model.tokenizer.encode(decoded_str)
75+
out = model.postprocess(tokens)
76+
assert len(out) == 2
77+
element1, element2 = out
78+
79+
assert [element1.type, element2.type] == expected_classes
80+
81+
82+
def test_predict():
83+
with mock.patch.object(
84+
largemodel.UnstructuredLargeModel,
85+
"predict_tokens",
86+
) as mock_predict_tokens, mock.patch.object(
87+
largemodel.UnstructuredLargeModel,
88+
"postprocess",
89+
) as mock_postprocess:
90+
model = largemodel.UnstructuredLargeModel()
91+
model.predict("hello")
92+
mock_predict_tokens.assert_called_once()
93+
mock_postprocess.assert_called_once()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.2" # pragma: no cover
1+
__version__ = "0.5.3" # pragma: no cover

0 commit comments

Comments
 (0)