@@ -30,10 +30,21 @@ class LayoutElements(TextRegions):
3030 element_probs : np .ndarray = field (default_factory = lambda : np .array ([]))
3131 element_class_ids : np .ndarray = field (default_factory = lambda : np .array ([]))
3232 element_class_id_map : dict [int , str ] = field (default_factory = dict )
33+ text_as_html : np .ndarray = field (default_factory = lambda : np .array ([]))
34+ table_as_cells : np .ndarray = field (default_factory = lambda : np .array ([]))
3335
3436 def __post_init__ (self ):
3537 element_size = self .element_coords .shape [0 ]
36- for attr in ("element_probs" , "element_class_ids" , "texts" ):
38+ # NOTE: maybe we should create an attribute _optional_attributes: list[str] to store this
39+ # list
40+ for attr in (
41+ "element_probs" ,
42+ "element_class_ids" ,
43+ "texts" ,
44+ "sources" ,
45+ "text_as_html" ,
46+ "table_as_cells" ,
47+ ):
3748 if getattr (self , attr ).size == 0 and element_size :
3849 setattr (self , attr , np .array ([None ] * element_size ))
3950
@@ -54,31 +65,37 @@ def __eq__(self, other: object) -> bool:
5465 [self .element_class_id_map [idx ] for idx in self .element_class_ids ]
5566 == [other .element_class_id_map [idx ] for idx in other .element_class_ids ]
5667 )
57- and self .source == other .source
68+ and np .array_equal (self .sources [mask ], other .sources [mask ])
69+ and np .array_equal (self .text_as_html [mask ], other .text_as_html [mask ])
70+ and np .array_equal (self .table_as_cells [mask ], other .table_as_cells [mask ])
5871 )
5972
def slice(self, indices) -> LayoutElements:
    """Return a new LayoutElements containing only the entries selected by `indices`.

    Each per-element array on this instance is indexed with `indices`; the
    `element_class_id_map` is handed to the new instance as-is.
    """
    # Attributes that hold one entry per element and therefore must be indexed together.
    per_element_fields = (
        "element_coords",
        "texts",
        "sources",
        "element_probs",
        "element_class_ids",
        "text_as_html",
        "table_as_cells",
    )
    selected = {name: getattr(self, name)[indices] for name in per_element_fields}
    return LayoutElements(
        element_class_id_map=self.element_class_id_map,
        **selected,
    )
7085
7186 @classmethod
7287 def concatenate (cls , groups : Iterable [LayoutElements ]) -> LayoutElements :
7388 """concatenate a sequence of LayoutElements in order as one LayoutElements"""
7489 coords , texts , probs , class_ids , sources = [], [], [], [], []
90+ text_as_html , table_as_cells = [], []
7591 class_id_reverse_map : dict [str , int ] = {}
7692 for group in groups :
7793 coords .append (group .element_coords )
7894 texts .append (group .texts )
7995 probs .append (group .element_probs )
80- if group .source :
81- sources .append (group .source )
96+ sources .append (group .sources )
97+ text_as_html .append (group .text_as_html )
98+ table_as_cells .append (group .table_as_cells )
8299
83100 idx = group .element_class_ids .copy ()
84101 if group .element_class_id_map :
@@ -97,13 +114,24 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
97114 element_probs = np .concatenate (probs ),
98115 element_class_ids = np .concatenate (class_ids ),
99116 element_class_id_map = {v : k for k , v in class_id_reverse_map .items ()},
100- source = sources [0 ] if sources else None ,
117+ sources = np .concatenate (sources ),
118+ text_as_html = np .concatenate (text_as_html ),
119+ table_as_cells = np .concatenate (table_as_cells ),
101120 )
102121
103- def as_list (self ):
104- """return a list of LayoutElement for backward compatibility"""
105- return [
106- LayoutElement .from_coords (
122+ def iter_elements (self ):
123+ """iter elements as one LayoutElement per iteration; this returns a generator and has less
124+ memory impact than the as_list method"""
125+ for (x1 , y1 , x2 , y2 ), text , prob , class_id , source , text_as_html , table_as_cells in zip (
126+ self .element_coords ,
127+ self .texts ,
128+ self .element_probs ,
129+ self .element_class_ids ,
130+ self .sources ,
131+ self .text_as_html ,
132+ self .table_as_cells ,
133+ ):
134+ yield LayoutElement .from_coords (
107135 x1 ,
108136 y1 ,
109137 x2 ,
@@ -115,15 +143,10 @@ def as_list(self):
115143 else None
116144 ),
117145 prob = None if np .isnan (prob ) else prob ,
118- source = self .source ,
146+ source = source ,
147+ text_as_html = text_as_html ,
148+ table_as_cells = table_as_cells ,
119149 )
120- for (x1 , y1 , x2 , y2 ), text , prob , class_id in zip (
121- self .element_coords ,
122- self .texts ,
123- self .element_probs ,
124- self .element_class_ids ,
125- )
126- ]
127150
128151 @classmethod
129152 def from_list (cls , elements : list ):
@@ -133,13 +156,15 @@ def from_list(cls, elements: list):
133156 coords = np .empty ((len_ele , 4 ), dtype = float )
134157 # text and probs can be Nones so use lists first then convert into array to avoid them being
135158 # filled as nan
136- texts = []
137- class_probs = []
159+ texts , text_as_html , table_as_cells , sources , class_probs = [], [], [], [], []
138160 class_types = np .empty ((len_ele ,), dtype = "object" )
139161
140162 for i , element in enumerate (elements ):
141163 coords [i ] = [element .bbox .x1 , element .bbox .y1 , element .bbox .x2 , element .bbox .y2 ]
142164 texts .append (element .text )
165+ sources .append (element .source )
166+ text_as_html .append (element .text_as_html )
167+ table_as_cells .append (element .table_as_cells )
143168 class_probs .append (element .prob )
144169 class_types [i ] = element .type or "None"
145170
@@ -152,7 +177,9 @@ def from_list(cls, elements: list):
152177 element_probs = np .array (class_probs ),
153178 element_class_ids = class_ids ,
154179 element_class_id_map = dict (zip (range (len (unique_ids )), unique_ids )),
155- source = elements [0 ].source if len_ele else None ,
180+ sources = np .array (sources ),
181+ text_as_html = np .array (text_as_html ),
182+ table_as_cells = np .array (table_as_cells ),
156183 )
157184
158185
@@ -162,6 +189,8 @@ class LayoutElement(TextRegion):
162189 prob : Optional [float ] = None
163190 image_path : Optional [str ] = None
164191 parent : Optional [LayoutElement ] = None
192+ text_as_html : Optional [str ] = None
193+ table_as_cells : Optional [str ] = None
165194
166195 def to_dict (self ) -> dict :
167196 """Converts the class instance to dictionary form."""
@@ -432,9 +461,8 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float =
432461
433462 final_attrs : dict [str , Any ] = {
434463 "element_class_id_map" : elements .element_class_id_map ,
435- "source" : elements .source ,
436464 }
437- for attr in ("element_class_ids" , "element_probs" , "texts" ):
465+ for attr in ("element_class_ids" , "element_probs" , "texts" , "sources" ):
438466 if (original_attr := getattr (elements , attr )) is None :
439467 continue
440468 final_attrs [attr ] = original_attr [sorted_by_area ][mask ][sorted_by_y1 ]
@@ -510,7 +538,7 @@ def clean_layoutelements_for_class(
510538
511539 final_coords = np .vstack ([target_coords [mask ], other_coords [other_mask ]])
512540 final_attrs : dict [str , Any ] = {"element_class_id_map" : elements .element_class_id_map }
513- for attr in ("element_class_ids" , "element_probs" , "texts" ):
541+ for attr in ("element_class_ids" , "element_probs" , "texts" , "sources" ):
514542 if (original_attr := getattr (elements , attr )) is None :
515543 continue
516544 final_attrs [attr ] = np .concatenate (
0 commit comments