55import numpy as np
66import pytest
77
8+ from unstructured_inference .constants import IsExtracted , Source
89from unstructured_inference .inference import elements
910from unstructured_inference .inference .elements import (
1011 Rectangle ,
12+ TextRegion ,
1113 TextRegions ,
1214)
1315from unstructured_inference .inference .layoutelement import (
@@ -56,7 +58,7 @@ def test_layoutelements():
5658 element_coords = coords ,
5759 element_class_ids = element_class_ids ,
5860 element_class_id_map = class_map ,
59- source = "yolox" ,
61+ source = Source . YOLOX ,
6062 )
6163
6264
@@ -307,7 +309,7 @@ def test_clean_layoutelements(test_layoutelements):
307309 elements [1 ].bbox .x2 ,
308310 elements [1 ].bbox .x2 ,
309311 ) == (2 , 2 , 3 , 3 )
310- assert elements [0 ].source == elements [1 ].source == "yolox"
312+ assert elements [0 ].source == elements [1 ].source == Source . YOLOX
311313
312314
313315@pytest .mark .parametrize (
@@ -408,29 +410,34 @@ def test_layoutelements_from_list_no_elements():
408410
409411def test_textregions_from_list_no_elements ():
410412 back = TextRegions .from_list (regions = [])
411- assert back .sources .size == 0
412- assert back .source is None
413+ assert back .is_extracted_array .size == 0
414+ assert back .is_extracted is None
413415 assert back .element_coords .size == 0
414416
415417
416418def test_layoutelements_concatenate ():
417419 layout1 = LayoutElements (
418420 element_coords = np .array ([[0 , 0 , 1 , 1 ], [1 , 1 , 2 , 2 ]]),
419421 texts = np .array (["a" , "two" ]),
420- source = "yolox" ,
422+ source = Source . YOLOX ,
421423 element_class_ids = np .array ([0 , 1 ]),
422424 element_class_id_map = {0 : "type0" , 1 : "type1" },
423425 )
424426 layout2 = LayoutElements (
425427 element_coords = np .array ([[10 , 10 , 2 , 2 ], [20 , 20 , 1 , 1 ]]),
426428 texts = np .array (["three" , "4" ]),
427- sources = np .array (["ocr" , "ocr" ]),
429+ sources = np .array ([Source . DETECTRON2_ONNX , Source . DETECTRON2_ONNX ]),
428430 element_class_ids = np .array ([0 , 1 ]),
429431 element_class_id_map = {0 : "type1" , 1 : "type2" },
430432 )
431433 joint = LayoutElements .concatenate ([layout1 , layout2 ])
432434 assert joint .texts .tolist () == ["a" , "two" , "three" , "4" ]
433- assert joint .sources .tolist () == ["yolox" , "yolox" , "ocr" , "ocr" ]
435+ assert [s .value for s in joint .sources .tolist ()] == [
436+ "yolox" ,
437+ "yolox" ,
438+ "detectron2_onnx" ,
439+ "detectron2_onnx" ,
440+ ]
434441 assert joint .element_class_ids .tolist () == [0 , 1 , 1 , 2 ]
435442 assert joint .element_class_id_map == {0 : "type0" , 1 : "type1" , 2 : "type2" }
436443
@@ -449,8 +456,8 @@ def test_layoutelements_concatenate():
449456 ]
450457 ),
451458 texts = np .array (["0" , "1" , "2" , "3" , "4" ]),
452- sources = np .array (["foo" , "foo" , "foo" , "foo" , "foo" ], dtype = "<U3" ),
453- source = np . str_ ( "foo" ) ,
459+ is_extracted_array = np .array ([IsExtracted . TRUE ] * 5 ),
460+ is_extracted = IsExtracted . TRUE ,
454461 ),
455462 LayoutElements (
456463 element_coords = np .array (
@@ -463,8 +470,10 @@ def test_layoutelements_concatenate():
463470 ]
464471 ),
465472 texts = np .array (["0" , "1" , "2" , "3" , "4" ]),
466- sources = np .array (["foo" , "foo" , "foo" , "foo" , "foo" ], dtype = "<U3" ),
467- source = np .str_ ("foo" ),
473+ sources = np .array ([Source .YOLOX ] * 5 ),
474+ source = Source .YOLOX ,
475+ is_extracted_array = np .array ([] * 5 ),
476+ is_extracted = IsExtracted .TRUE ,
468477 element_probs = np .array ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 ]),
469478 ),
470479 ],
@@ -479,3 +488,108 @@ def test_textregions_support_numpy_slicing(test_elements):
479488 )
480489 if isinstance (test_elements , LayoutElements ):
481490 np .testing .assert_almost_equal (test_elements [1 :4 ].element_probs , np .array ([0.1 , 0.2 , 0.3 ]))
491+
492+
493+ def test_textregions_from_list_collects_sources ():
494+ """Test that TextRegions.from_list() collects both source and text_source from regions"""
495+ from unstructured_inference .inference .elements import TextRegion
496+
497+ regions = [
498+ TextRegion .from_coords (
499+ 0 , 0 , 10 , 10 , text = "first" , source = Source .YOLOX , is_extracted = IsExtracted .TRUE
500+ ),
501+ TextRegion .from_coords (
502+ 10 ,
503+ 10 ,
504+ 20 ,
505+ 20 ,
506+ text = "second" ,
507+ source = Source .DETECTRON2_ONNX ,
508+ is_extracted = IsExtracted .TRUE ,
509+ ),
510+ ]
511+
512+ text_regions = TextRegions .from_list (regions )
513+
514+ # This should fail because from_list() doesn't collect sources
515+ assert text_regions .sources .size > 0 , "sources array should not be empty"
516+ assert text_regions .sources [0 ] == Source .YOLOX
517+ assert text_regions .sources [1 ] == Source .DETECTRON2_ONNX
518+
519+
520+ def test_textregions_has_sources_field ():
521+ """Test that TextRegions has a sources field"""
522+ text_regions = TextRegions (element_coords = np .array ([[0 , 0 , 10 , 10 ]]))
523+
524+ # This should fail because TextRegions doesn't have a sources field
525+ assert hasattr (text_regions , "sources" ), "TextRegions should have a sources field"
526+ assert hasattr (text_regions , "source" ), "TextRegions should have a source field"
527+
528+
529+ def test_textregions_iter_elements_preserves_source ():
530+ """Test that TextRegions.iter_elements() preserves source property"""
531+ from unstructured_inference .inference .elements import TextRegion
532+
533+ regions = [
534+ TextRegion .from_coords (
535+ 0 , 0 , 10 , 10 , text = "first" , source = Source .YOLOX , is_extracted = IsExtracted .TRUE
536+ ),
537+ ]
538+ text_regions = TextRegions .from_list (regions )
539+
540+ elements = list (text_regions .iter_elements ())
541+
542+ # This should fail because iter_elements() doesn't pass source to TextRegion.from_coords()
543+ assert elements [0 ].source == Source .YOLOX , "iter_elements() should preserve source"
544+
545+
546+ def test_textregions_slice_preserves_sources ():
547+ """Test that TextRegions slicing preserves sources array"""
548+ from unstructured_inference .inference .elements import TextRegion
549+
550+ regions = [
551+ TextRegion .from_coords (
552+ 0 , 0 , 10 , 10 , text = "first" , source = Source .YOLOX , is_extracted = IsExtracted .TRUE
553+ ),
554+ TextRegion .from_coords (
555+ 10 ,
556+ 10 ,
557+ 20 ,
558+ 20 ,
559+ text = "second" ,
560+ source = Source .DETECTRON2_ONNX ,
561+ is_extracted = IsExtracted .TRUE ,
562+ ),
563+ ]
564+ text_regions = TextRegions .from_list (regions )
565+
566+ sliced = text_regions [0 :1 ]
567+
568+ # This should fail because slice() doesn't handle sources
569+ assert sliced .sources .size > 0 , "Sliced TextRegions should have sources"
570+ assert sliced .sources [0 ] == Source .YOLOX
571+ assert sliced .is_extracted_array [0 ] is IsExtracted .TRUE
572+
573+
574+ def test_textregions_post_init_handles_sources ():
575+ """Test that TextRegions.__post_init__() handles sources array initialization"""
576+ # Create with source but no sources array
577+ text_regions = TextRegions (
578+ element_coords = np .array ([[0 , 0 , 10 , 10 ], [10 , 10 , 20 , 20 ]]), source = Source .YOLOX
579+ )
580+
581+ # This should fail because __post_init__() doesn't handle sources
582+ assert text_regions .sources .size > 0 , "sources should be initialized from source"
583+ assert text_regions .sources [0 ] == Source .YOLOX
584+ assert text_regions .sources [1 ] == Source .YOLOX
585+
586+
587+ def test_textregions_from_coords_accepts_source ():
588+ """Test that TextRegion.from_coords() accepts source parameter"""
589+ # This should fail because from_coords() doesn't accept source parameter
590+ region = TextRegion .from_coords (
591+ 0 , 0 , 10 , 10 , text = "test" , source = Source .YOLOX , is_extracted = IsExtracted .TRUE
592+ )
593+
594+ assert region .source == Source .YOLOX
595+ assert region .is_extracted
0 commit comments