1010import unstructured_inference .models .base as models
1111from unstructured_inference .inference import elements , layout , layoutelement
1212from unstructured_inference .models import detectron2 , tesseract
13+ from unstructured_inference .models .unstructuredmodel import (
14+ UnstructuredElementExtractionModel ,
15+ UnstructuredObjectDetectionModel ,
16+ )
1317
1418
1519@pytest .fixture ()
@@ -84,15 +88,15 @@ def test_get_page_elements(monkeypatch, mock_final_layout):
8488 number = 0 ,
8589 image = image ,
8690 layout = mock_final_layout ,
87- model = MockLayoutModel (mock_final_layout ),
91+ detection_model = MockLayoutModel (mock_final_layout ),
8892 )
8993
90- elements = page .get_elements_with_model (inplace = False )
94+ elements = page .get_elements_with_detection_model (inplace = False )
9195
9296 assert str (elements [0 ]) == "A Catchy Title"
9397 assert str (elements [1 ]).startswith ("A very repetitive narrative." )
9498
95- page .get_elements_with_model (inplace = True )
99+ page .get_elements_with_detection_model (inplace = True )
96100 assert elements == page .elements
97101
98102
@@ -130,9 +134,9 @@ def test_get_page_elements_with_ocr(monkeypatch):
130134 number = 0 ,
131135 image = image ,
132136 layout = doc_initial_layout ,
133- model = MockLayoutModel (doc_final_layout ),
137+ detection_model = MockLayoutModel (doc_final_layout ),
134138 )
135- page .get_elements_with_model ()
139+ page .get_elements_with_detection_model ()
136140
137141 assert str (page ) == "\n \n An Even Catchier Title"
138142
@@ -152,7 +156,7 @@ def test_read_pdf(monkeypatch, mock_initial_layout, mock_final_layout):
152156
153157 with patch .object (layout , "load_pdf" , return_value = (layouts , images )):
154158 model = layout .get_model ("detectron2_lp" )
155- doc = layout .DocumentLayout .from_file ("fake-file.pdf" , model = model )
159+ doc = layout .DocumentLayout .from_file ("fake-file.pdf" , detection_model = model )
156160
157161 assert str (doc ).startswith ("A Catchy Title" )
158162 assert str (doc ).count ("A Catchy Title" ) == 2 # Once for each page
@@ -172,7 +176,17 @@ def test_process_data_with_model(monkeypatch, mock_final_layout, model_name):
172176 "from_file" ,
173177 lambda * args , ** kwargs : layout .DocumentLayout .from_pages ([]),
174178 )
175- with patch ("builtins.open" , mock_open (read_data = b"000000" )), open ("" ) as fp :
179+
180+ def new_isinstance (obj , cls ):
181+ if type (obj ) == MockLayoutModel :
182+ return True
183+ else :
184+ return isinstance (obj , cls )
185+
186+ with patch ("builtins.open" , mock_open (read_data = b"000000" )), patch (
187+ "unstructured_inference.inference.layout.UnstructuredObjectDetectionModel" ,
188+ MockLayoutModel ,
189+ ), open ("" ) as fp :
176190 assert layout .process_data_with_model (fp , model_name = model_name )
177191
178192
@@ -305,7 +319,7 @@ def test_from_image_file(monkeypatch, mock_final_layout, filetype):
305319 def mock_get_elements (self , * args , ** kwargs ):
306320 self .elements = [mock_final_layout ]
307321
308- monkeypatch .setattr (layout .PageLayout , "get_elements_with_model " , mock_get_elements )
322+ monkeypatch .setattr (layout .PageLayout , "get_elements_with_detection_model " , mock_get_elements )
309323 elements = (
310324 layout .DocumentLayout .from_image_file (f"sample-docs/loremipsum.{ filetype } " )
311325 .pages [0 ]
@@ -342,7 +356,7 @@ def test_get_elements_from_layout(mock_initial_layout, idx):
342356
343357def test_page_numbers_in_page_objects ():
344358 with patch (
345- "unstructured_inference.inference.layout.PageLayout.get_elements_with_model " ,
359+ "unstructured_inference.inference.layout.PageLayout.get_elements_with_detection_model " ,
346360 ) as mock_get_elements :
347361 doc = layout .DocumentLayout .from_file ("sample-docs/layout-parser-paper.pdf" )
348362 mock_get_elements .assert_called ()
@@ -352,12 +366,16 @@ def test_page_numbers_in_page_objects():
352366@pytest .mark .parametrize (
353367 ("fixed_layouts" , "called_method" , "not_called_method" ),
354368 [
355- ([MockLayout ()], "get_elements_from_layout" , "get_elements_with_model " ),
356- (None , "get_elements_with_model " , "get_elements_from_layout" ),
369+ ([MockLayout ()], "get_elements_from_layout" , "get_elements_with_detection_model " ),
370+ (None , "get_elements_with_detection_model " , "get_elements_from_layout" ),
357371 ],
358372)
359373def test_from_file_fixed_layout (fixed_layouts , called_method , not_called_method ):
360- with patch .object (layout .PageLayout , "get_elements_with_model" , return_value = []), patch .object (
374+ with patch .object (
375+ layout .PageLayout ,
376+ "get_elements_with_detection_model" ,
377+ return_value = [],
378+ ), patch .object (
361379 layout .PageLayout ,
362380 "get_elements_from_layout" ,
363381 return_value = [],
@@ -524,7 +542,8 @@ def test_load_pdf_with_multicolumn_layout_and_ocr(filename="sample-docs/design-t
524542 assert element .text .startswith (test_snippets [i ])
525543
526544
527- def test_annotate ():
545+ @pytest .mark .parametrize ("colors" , ["red" , None ])
546+ def test_annotate (colors ):
528547 test_image_arr = np .ones ((100 , 100 , 3 ), dtype = "uint8" )
529548 image = Image .fromarray (test_image_arr )
530549 page = layout .PageLayout (number = 1 , image = image , layout = None )
@@ -533,7 +552,7 @@ def test_annotate():
533552 coords2 = (1 , 10 , 7 , 11 )
534553 rect2 = elements .Rectangle (* coords2 )
535554 page .elements = [rect1 , rect2 ]
536- annotated_image = page .annotate (colors = "red" )
555+ annotated_image = page .annotate (colors = colors )
537556 annotated_array = np .array (annotated_image )
538557 for x1 , y1 , x2 , y2 in [coords1 , coords2 ]:
539558 # Make sure the pixels on the edge of the box are red
@@ -595,8 +614,129 @@ def test_layout_order(ordering_layout):
595614 layout ,
596615 "load_pdf" ,
597616 lambda * args , ** kwargs : ([[]], [mock_image ]),
617+ ), patch .object (
618+ layout ,
619+ "UnstructuredObjectDetectionModel" ,
620+ object ,
598621 ):
599622 doc = layout .DocumentLayout .from_file ("sample-docs/layout-parser-paper.pdf" )
600623 page = doc .pages [0 ]
601624 for n , element in enumerate (page .elements ):
602625 assert element .text == str (n )
626+
627+
628+ def test_page_layout_raises_when_multiple_models_passed (mock_image , mock_initial_layout ):
629+ with pytest .raises (ValueError ):
630+ layout .PageLayout (
631+ 0 ,
632+ mock_image ,
633+ mock_initial_layout ,
634+ detection_model = "something" ,
635+ element_extraction_model = "something else" ,
636+ )
637+
638+
639+ class MockElementExtractionModel :
640+ def __call__ (self , x ):
641+ return [1 , 2 , 3 ]
642+
643+
644+ @pytest .mark .parametrize (("inplace" , "expected" ), [(True , None ), (False , [1 , 2 , 3 ])])
645+ def test_get_elements_using_image_extraction (mock_image , inplace , expected ):
646+ page = layout .PageLayout (
647+ 1 ,
648+ mock_image ,
649+ None ,
650+ element_extraction_model = MockElementExtractionModel (),
651+ )
652+ assert page .get_elements_using_image_extraction (inplace = inplace ) == expected
653+
654+
655+ def test_get_elements_using_image_extraction_raises_with_no_extraction_model (mock_image ):
656+ page = layout .PageLayout (1 , mock_image , None , element_extraction_model = None )
657+ with pytest .raises (ValueError ):
658+ page .get_elements_using_image_extraction ()
659+
660+
661+ def test_get_elements_with_detection_model_raises_with_wrong_default_model (monkeypatch ):
662+ monkeypatch .setattr (layout , "get_model" , lambda * x : MockLayoutModel (mock_final_layout ))
663+ page = layout .PageLayout (1 , mock_image , None )
664+ with pytest .raises (NotImplementedError ):
665+ page .get_elements_with_detection_model ()
666+
667+
668+ @pytest .mark .parametrize (
669+ (
670+ "detection_model" ,
671+ "element_extraction_model" ,
672+ "detection_model_called" ,
673+ "element_extraction_model_called" ,
674+ ),
675+ [(None , "asdf" , False , True ), ("asdf" , None , True , False )],
676+ )
677+ def test_from_image (
678+ mock_image ,
679+ detection_model ,
680+ element_extraction_model ,
681+ detection_model_called ,
682+ element_extraction_model_called ,
683+ ):
684+ with patch .object (
685+ layout .PageLayout ,
686+ "get_elements_using_image_extraction" ,
687+ ) as mock_image_extraction , patch .object (
688+ layout .PageLayout ,
689+ "get_elements_with_detection_model" ,
690+ ) as mock_detection :
691+ layout .PageLayout .from_image (
692+ mock_image ,
693+ detection_model = detection_model ,
694+ element_extraction_model = element_extraction_model ,
695+ )
696+ assert mock_image_extraction .called == element_extraction_model_called
697+ assert mock_detection .called == detection_model_called
698+
699+
700+ class MockUnstructuredElementExtractionModel (UnstructuredElementExtractionModel ):
701+ def initialize (self , * args , ** kwargs ):
702+ return super ().initialize (* args , ** kwargs )
703+
704+ def predict (self , x : Image ):
705+ return super ().predict (x )
706+
707+
708+ class MockUnstructuredDetectionModel (UnstructuredObjectDetectionModel ):
709+ def initialize (self , * args , ** kwargs ):
710+ return super ().initialize (* args , ** kwargs )
711+
712+ def predict (self , x : Image ):
713+ return super ().predict (x )
714+
715+
716+ @pytest .mark .parametrize (
717+ ("model_type" , "is_detection_model" ),
718+ [
719+ (MockUnstructuredElementExtractionModel , False ),
720+ (MockUnstructuredDetectionModel , True ),
721+ ],
722+ )
723+ def test_process_file_with_model_routing (monkeypatch , model_type , is_detection_model ):
724+ model = model_type ()
725+ monkeypatch .setattr (layout , "get_model" , lambda * x : model )
726+ with patch .object (layout .DocumentLayout , "from_file" ) as mock_from_file :
727+ layout .process_file_with_model ("asdf" , model_name = "fake" , is_image = False )
728+ if is_detection_model :
729+ detection_model = model
730+ element_extraction_model = None
731+ else :
732+ detection_model = None
733+ element_extraction_model = model
734+ mock_from_file .assert_called_once_with (
735+ "asdf" ,
736+ detection_model = detection_model ,
737+ element_extraction_model = element_extraction_model ,
738+ ocr_strategy = "auto" ,
739+ ocr_languages = "eng" ,
740+ fixed_layouts = None ,
741+ extract_tables = False ,
742+ )
0 commit comments