3232class LayoutElements (TextRegions ):
3333 element_probs : np .ndarray = field (default_factory = lambda : np .array ([]))
3434 element_class_ids : np .ndarray = field (default_factory = lambda : np .array ([]))
35- element_class_id_map : dict [int , str ] | None = None
35+ element_class_id_map : dict [int , str ] = field ( default_factory = dict )
3636
def __post_init__(self):
    """Normalize the per-element arrays once the dataclass fields are set.

    Any of ``element_probs``, ``element_class_ids`` or ``texts`` that is empty while
    ``element_coords`` has at least one row is replaced by an array holding one ``None``
    per row. ``element_probs`` is then converted to a float array, which turns those
    ``None`` placeholders into ``nan``.
    """
    element_size = self.element_coords.shape[0]
    for attr in ("element_probs", "element_class_ids", "texts"):
        if getattr(self, attr).size == 0 and element_size:
            setattr(self, attr, np.array([None] * element_size))

    # Cast only after the padding above. Casting first would leave padded probabilities as
    # object-dtype ``None`` values, which ``np.isnan`` (used by ``__eq__`` and ``as_list``)
    # cannot handle; casting afterwards represents every missing probability as ``nan``.
    self.element_probs = self.element_probs.astype(float)
44+
def __eq__(self, other: object) -> bool:
    """Compare two ``LayoutElements`` by content.

    Two instances are equal when their coordinates, texts and source match, their
    probabilities have ``nan`` in the same positions and are equal elsewhere, and their
    class ids resolve to the same sequence of class names. Comparing resolved names rather
    than raw ids lets instances with differently numbered id maps compare equal.

    Returns ``NotImplemented`` when ``other`` is not a ``LayoutElements``.
    """
    if not isinstance(other, LayoutElements):
        return NotImplemented

    def class_names(elements):
        # Ids absent from the map (including the ``None`` placeholders that
        # ``__post_init__`` uses for missing class ids) resolve to ``None`` instead of
        # raising ``KeyError``; ``or {}`` also tolerates a map explicitly set to ``None``.
        id_map = elements.element_class_id_map or {}
        return [id_map.get(class_id) for class_id in elements.element_class_ids]

    # ``nan != nan``, so probabilities are compared in two steps: the positions of the
    # nans must agree, then the remaining (non-nan) values must agree.
    mask = ~np.isnan(self.element_probs)
    other_mask = ~np.isnan(other.element_probs)
    return (
        np.array_equal(self.element_coords, other.element_coords)
        and np.array_equal(self.texts, other.texts)
        and np.array_equal(mask, other_mask)
        # Only reached when both masks are identical, so indexing each side with its own
        # mask selects the same positions.
        and np.array_equal(self.element_probs[mask], other.element_probs[other_mask])
        and class_names(self) == class_names(other)
        and self.source == other.source
    )
62+
4563 def slice (self , indices ) -> LayoutElements :
4664 """slice and return only selected indices"""
4765 return LayoutElements (
@@ -85,10 +103,10 @@ def as_list(self):
85103 text = text ,
86104 type = (
87105 self .element_class_id_map [class_id ]
88- if class_id and self .element_class_id_map
106+ if class_id is not None and self .element_class_id_map
89107 else None
90108 ),
91- prob = prob ,
109+ prob = None if np . isnan ( prob ) else prob ,
92110 source = self .source ,
93111 )
94112 for (x1 , y1 , x2 , y2 ), text , prob , class_id in zip (
@@ -99,6 +117,36 @@ def as_list(self):
99117 )
100118 ]
101119
@classmethod
def from_list(cls, elements: list) -> LayoutElements:
    """Create a ``LayoutElements`` from a list of ``LayoutElement`` objects.

    Each object contributes its ``bbox`` corners (``x1``, ``y1``, ``x2``, ``y2``), ``text``,
    ``prob`` and ``type``. The objects must share the same source; the source of the first
    object is used for the whole collection. An empty list yields an empty collection
    whose source is ``None``.
    """
    len_ele = len(elements)
    coords = np.empty((len_ele, 4), dtype=float)
    # text and probs can be None, so collect them in lists first and convert to arrays
    # afterwards to avoid them being filled as nan
    texts = []
    class_probs = []
    class_types = np.empty((len_ele,), dtype="object")

    for i, element in enumerate(elements):
        bbox = element.bbox
        coords[i] = [bbox.x1, bbox.y1, bbox.x2, bbox.y2]
        texts.append(element.text)
        class_probs.append(element.prob)
        # np.unique has to sort the values, and None does not order against str, so a
        # missing type is temporarily encoded as the string "None".
        class_types[i] = element.type or "None"

    unique_ids, class_ids = np.unique(class_types, return_inverse=True)
    # Restore the real None for the placeholder introduced above.
    unique_ids[unique_ids == "None"] = None

    return cls(
        element_coords=coords,
        texts=np.array(texts),
        element_probs=np.array(class_probs),
        element_class_ids=class_ids,
        element_class_id_map=dict(enumerate(unique_ids)),
        # Guard the empty case instead of raising IndexError on elements[0].
        # NOTE(review): assumes the constructor accepts source=None - confirm against the
        # base class definition.
        source=elements[0].source if len_ele else None,
    )
149+
102150
103151@dataclass
104152class LayoutElement (TextRegion ):
@@ -315,7 +363,7 @@ def partition_groups_from_regions(regions: TextRegions) -> List[TextRegions]:
315363 regions, each list corresponding with a group"""
316364 if len (regions ) == 0 :
317365 return []
318- padded_coords = regions .element_coords .copy ()
366+ padded_coords = regions .element_coords .copy (). astype ( float )
319367 v_pad = (regions .y2 - regions .y1 ) * inference_config .ELEMENTS_V_PADDING_COEF
320368 h_pad = (regions .x2 - regions .x1 ) * inference_config .ELEMENTS_H_PADDING_COEF
321369 padded_coords [:, 0 ] -= h_pad
0 commit comments