55from logging import getLogger
66from typing import Any
77
8- from kensho_kenverters .constants import (
8+ from .constants import (
99 CATEGORY_KEY ,
1010 DOCUMENT_CATEGORY_KEY ,
1111 ELEMENT_TITLE_CONTENT_CATEGORIES ,
1212 EMPTY_STRING ,
13+ FIGURE_EXTRACTED_TABLE_KEY ,
1314 LOCATIONS_KEY ,
14- TABLE_CONTENT_CATEGORIES ,
1515 TABLE_KEY ,
1616 TEXT_KEY ,
1717 AnnotationType ,
1818 ContentCategory ,
1919 TableType ,
2020)
21- from kensho_kenverters .extract_output_models import ContentModel , LocationModel
22- from kensho_kenverters .utils import load_output_to_pydantic
21+ from .extract_output_models import AnnotationModel , ContentModel , LocationModel
22+ from .output_to_tables import (
23+ build_content_grid_from_figure_extracted_table_cell_annotations ,
24+ get_table_uid_to_annotations_mapping ,
25+ get_table_uid_to_cells_mapping ,
26+ )
27+ from .utils import load_output_to_pydantic
2328
2429logger = getLogger (__name__ )
2530
@@ -110,14 +115,18 @@ def _create_segment(
110115 content : ContentModel ,
111116 uid_to_index : dict [str , tuple [int , int ]],
112117 uid_to_span : dict [str , tuple [int , int ]],
118+ figure_extracted_table_uid_to_cell_annotations : dict [str , list [AnnotationModel ]],
113119) -> dict [str , Any ]:
114120 """Create segment dictionary from the content, and if applicable its matching table cells."""
115121 segment : dict [str , Any ] = {}
116122 # DOCUMENT is just a head node
117123 if content .type == DOCUMENT_CATEGORY_KEY :
118124 return {}
119125 # For tables, use table cell structures read above
120- elif content .type in TABLE_CONTENT_CATEGORIES :
126+ elif content .type in (
127+ ContentCategory .TABLE .value ,
128+ ContentCategory .TABLE_OF_CONTENTS .value ,
129+ ):
121130 # Construct the table from cells
122131 table_cells = content .children
123132 # Drop tables with no cells
@@ -132,7 +141,24 @@ def _create_segment(
132141 TABLE_KEY : table ,
133142 TEXT_KEY : table_to_markdown (table ),
134143 }
135- elif content .type == ContentCategory .TABLE_CELL .value :
144+ elif content .type == ContentCategory .FIGURE_EXTRACTED_TABLE .value :
145+ figure_extracted_table = (
146+ build_content_grid_from_figure_extracted_table_cell_annotations (
147+ figure_extracted_table_uid_to_cell_annotations [content .uid ]
148+ )
149+ )
150+ # Drop tables with length 0
151+ if len (figure_extracted_table ) == 0 :
152+ return {}
153+ segment = {
154+ CATEGORY_KEY : content .type .lower (),
155+ FIGURE_EXTRACTED_TABLE_KEY : figure_extracted_table ,
156+ TEXT_KEY : table_to_markdown (figure_extracted_table ),
157+ }
158+ elif content .type in (
159+ ContentCategory .TABLE_CELL .value ,
160+ ContentCategory .FIGURE_EXTRACTED_TABLE_CELL .value ,
161+ ):
136162 # Skip - already accounted for in tables
137163 return {}
138164 # For texts and titles, add the text content and the category
@@ -153,6 +179,7 @@ def _get_segments_from_all_children(
153179 content : ContentModel ,
154180 uid_to_index : dict [str , tuple [int , int ]],
155181 uid_to_span : dict [str , tuple [int , int ]],
182+ figure_extracted_table_uid_to_cell_annotations : dict [str , list [AnnotationModel ]],
156183 return_locations : bool ,
157184 segments : list [dict [str , Any ]],
158185 visited : list [str ],
@@ -162,7 +189,12 @@ def _get_segments_from_all_children(
162189 return
163190
164191 # Get current segment from content and add to list
165- segment = _create_segment (content , uid_to_index , uid_to_span )
192+ segment = _create_segment (
193+ content ,
194+ uid_to_index ,
195+ uid_to_span ,
196+ figure_extracted_table_uid_to_cell_annotations ,
197+ )
166198 visited .append (content .uid )
167199 if segment :
168200 if return_locations :
@@ -172,7 +204,13 @@ def _get_segments_from_all_children(
172204 # Get all children segments
173205 for child in content .children :
174206 _get_segments_from_all_children (
175- child , uid_to_index , uid_to_span , return_locations , segments , visited
207+ child ,
208+ uid_to_index ,
209+ uid_to_span ,
210+ figure_extracted_table_uid_to_cell_annotations ,
211+ return_locations ,
212+ segments ,
213+ visited ,
176214 )
177215
178216
@@ -213,15 +251,38 @@ def convert_output_to_items_list(
213251 for uid in content_uids :
214252 uid_to_index [uid ] = (row , col )
215253 uid_to_span [uid ] = annotation .data .span
254+ elif annotation .type == AnnotationType .FIGURE_EXTRACTED_TABLE_STRUCTURE .value :
255+ continue
216256 else :
217257 raise TypeError (f"{ annotation .type } is not a supported annotation type" )
218258
259+ figure_extracted_table_cell_annotations = [
260+ annotation
261+ for annotation in annotations
262+ if annotation .type == AnnotationType .FIGURE_EXTRACTED_TABLE_STRUCTURE .value
263+ ]
264+ table_uid_to_cells_mapping = get_table_uid_to_cells_mapping (
265+ parsed_serialized_document .content_tree
266+ )
267+ figure_extracted_table_uid_to_cell_annotations = (
268+ get_table_uid_to_annotations_mapping (
269+ table_uid_to_cells_mapping ,
270+ figure_extracted_table_cell_annotations ,
271+ )
272+ )
273+
219274 # Parse content into segments
220275 content_tree = parsed_serialized_document .content_tree
221276 segments : list [dict [str , Any ]] = []
222277 visited : list [str ] = []
223278 _get_segments_from_all_children (
224- content_tree , uid_to_index , uid_to_span , return_locations , segments , visited
279+ content_tree ,
280+ uid_to_index ,
281+ uid_to_span ,
282+ figure_extracted_table_uid_to_cell_annotations ,
283+ return_locations ,
284+ segments ,
285+ visited ,
225286 )
226287 return segments
227288
0 commit comments