11# Copyright 2024-present Kensho Technologies, LLC.
22"""Functions to extract the tables in the output and turn them into pandas DataFrames."""
33
4+ import typing
45from collections import defaultdict
56from typing import Any , Sequence
67
1718 LocationModel ,
1819 LocationType ,
1920 Table ,
21+ TableCategoryType ,
2022)
2123from kensho_kenverters .tables_utils import (
2224 convert_table_to_pd_df ,
@@ -36,7 +38,11 @@ def _get_table_uid_to_cells_mapping(
3638 cells = [
3739 child
3840 for child in content .children
39- if child .type == ContentCategory .TABLE_CELL .value
41+ if child .type
42+ in (
43+ ContentCategory .TABLE_CELL .value ,
44+ ContentCategory .FIGURE_EXTRACTED_TABLE_CELL .value ,
45+ )
4046 ]
4147 current_mapping [content .uid ] = cells
4248 elif len (content .children ) > 0 :
@@ -47,6 +53,22 @@ def _get_table_uid_to_cells_mapping(
4753 return current_mapping
4854
4955
56+ def _get_table_uid_to_types_mapping (
57+ content : ContentModel ,
58+ ) -> dict [str , TableCategoryType ]:
59+ """Recursively get table uids to table types mapping."""
60+ table_uid_to_types : dict [str , TableCategoryType ] = {}
61+ if content .type in TABLE_CONTENT_CATEGORIES :
62+ # Termination condition 1
63+ table_uid_to_types [content .uid ] = typing .cast (TableCategoryType , content .type )
64+ elif len (content .children ) > 0 :
65+ for child in content .children :
66+ # Recursive call to children
67+ nested_mapping = _get_table_uid_to_types_mapping (child )
68+ table_uid_to_types .update (nested_mapping )
69+ return table_uid_to_types
70+
71+
5072def _get_table_uid_to_locations_mapping (
5173 content : ContentModel ,
5274) -> dict [str , list [LocationType ]]:
@@ -83,31 +105,30 @@ def _get_table_uid_to_annotations_mapping(
83105 return table_to_annotations
84106
85107
86- def _build_grid_from_table_cell_annotations (
87- annotations : Sequence [AnnotationModel ], duplicate_content_flag : bool = False
108+ def _build_uids_grid_from_table_cell_annotations (
109+ annotations : Sequence [AnnotationModel ],
110+ duplicate_content_flag : bool = False ,
88111) -> list [list [list [str ]]]:
89112 """Build grid where each location has a list of content uids."""
90113 if any (
91114 annotation .type != AnnotationType .TABLE_STRUCTURE .value
92115 for annotation in annotations
93116 ):
94117 raise ValueError (
95- "Table grid can only be built from table structure annotations."
118+ "Content uids grid can only be built from table structure annotations."
96119 )
97-
98120 duplicated_annotations = duplicate_spanning_annotations (
99121 annotations , duplicate_content_flag
100122 )
123+
101124 index_to_uids_mapping = defaultdict (
102125 list ,
103126 {
104127 annotation .data .index : annotation .content_uids
105128 for annotation in duplicated_annotations
106129 },
107130 )
108-
109131 n_rows , n_cols = get_table_shape (duplicated_annotations )
110-
111132 rows : list [list [list [str ]]] = []
112133 for row_index in range (n_rows ):
113134 current_row = []
@@ -117,11 +138,51 @@ def _build_grid_from_table_cell_annotations(
117138 return rows
118139
119140
141+ def _build_content_grid_from_figure_extracted_table_cell_annotations (
142+ annotations : Sequence [AnnotationModel ],
143+ ) -> list [list [str ]]:
144+ """Build content grid where each location has a string of content."""
145+ if any (
146+ annotation .type != AnnotationType .FIGURE_EXTRACTED_TABLE_STRUCTURE .value
147+ for annotation in annotations
148+ ):
149+ raise ValueError (
150+ "Content grid can only be built from figure extracted table structure annotations."
151+ )
152+
153+ if any (annotation .data .value is None for annotation in annotations ):
154+ raise ValueError (
155+ "Data value of figure extracted table structure "
156+ "annotations cannot be None."
157+ )
158+ # If annotations are figure extracted table structure, we fill the grids
159+ # with extracted values.
160+ n_rows , n_cols = get_table_shape (annotations )
161+ index_to_annotation_value_mapping = {}
162+ for annotation in annotations :
163+ if annotation .data .value is not None :
164+ index_to_annotation_value_mapping [annotation .data .index ] = (
165+ annotation .data .value
166+ )
167+ else :
168+ index_to_annotation_value_mapping [annotation .data .index ] = ""
169+ rows : list [list [str ]] = []
170+ for row_index in range (n_rows ):
171+ current_content_row = []
172+ for col_index in range (n_cols ):
173+ current_content_row .append (
174+ index_to_annotation_value_mapping [(row_index , col_index )]
175+ )
176+ rows .append (current_content_row )
177+ return rows
178+
179+
120180def _convert_uid_grid_to_content_grid (
121181 uid_grid : list [list [list [str ]]], cell_contents : Sequence [ContentModel ]
122182) -> list [list [str ]]:
123183 """Convert a UID grid to content grid."""
124184 uids_to_content = {cell .uid : cell .content for cell in cell_contents }
185+
125186 content_grid = []
126187 for uid_row in uid_grid :
127188 content_row = []
@@ -145,8 +206,8 @@ def _convert_uid_grid_to_content_grid(
145206def build_table_grids (
146207 serialized_document : dict [str , Any ],
147208 duplicate_merged_cells_content_flag : bool = True ,
148- ) -> dict [str , list [list [str ]]]:
149- """Convert serialized tables to a 2D grid of strings.
209+ ) -> dict [str , tuple [ TableCategoryType , list [list [str ] ]]]:
210+ """Convert serialized tables to a table type and a 2D grid of strings.
150211
151212 Args:
152213 serialized_document: a serialized document
@@ -155,44 +216,62 @@ def build_table_grids(
155216 empty.
156217
157218 Returns:
158- a mapping of table UIDs to table grid structures
219+ a mapping of table UIDs to the tuple of table type and table grid structures
159220
160221 Example Output:
161222 {
162- '1': [['header1', 'header2'], ['row1_val', 'row2_val']],
163- '2': [['another_header1'], ['another_row1_val']]
223+ '1': ("TABLE", [['header1', 'header2'], ['row1_val', 'row2_val']]) ,
224+ '2': ("FIGURE_EXTRACTED_TABLE", [['another_header1'], ['another_row1_val']])
164225 }
165226 """
166227 parsed_serialized_document = load_output_to_pydantic (serialized_document )
167228 annotations = parsed_serialized_document .annotations
168229 content = parsed_serialized_document .content_tree
169230
170231 table_uid_to_cells_mapping = _get_table_uid_to_cells_mapping (content )
232+ table_uid_to_type_mapping = _get_table_uid_to_types_mapping (content )
171233
172234 table_cell_annotations = [
173235 annotation
174236 for annotation in annotations
175- if annotation .type == AnnotationType .TABLE_STRUCTURE .value
237+ if annotation .type
238+ in (
239+ AnnotationType .TABLE_STRUCTURE .value ,
240+ AnnotationType .FIGURE_EXTRACTED_TABLE_STRUCTURE .value ,
241+ )
176242 ]
177243 table_uid_to_cell_annotations = _get_table_uid_to_annotations_mapping (
178244 table_uid_to_cells_mapping , table_cell_annotations
179245 )
180246
181247 tables = {}
182248 for table_uid , cell_annotations in table_uid_to_cell_annotations .items ():
183- grid = _build_grid_from_table_cell_annotations (
184- cell_annotations , duplicate_content_flag = duplicate_merged_cells_content_flag
185- )
186- cell_contents = table_uid_to_cells_mapping [table_uid ]
187- content_grid = _convert_uid_grid_to_content_grid (grid , cell_contents )
188- tables [table_uid ] = content_grid
249+ if table_uid_to_type_mapping [table_uid ] in (
250+ ContentCategory .TABLE .value ,
251+ ContentCategory .TABLE_OF_CONTENTS .value ,
252+ ):
253+ uids_grid = _build_uids_grid_from_table_cell_annotations (
254+ cell_annotations ,
255+ duplicate_content_flag = duplicate_merged_cells_content_flag ,
256+ )
257+ cell_contents = table_uid_to_cells_mapping [table_uid ]
258+ content_grid = _convert_uid_grid_to_content_grid (uids_grid , cell_contents )
259+ tables [table_uid ] = (table_uid_to_type_mapping [table_uid ], content_grid )
260+ else :
261+ content_grid = (
262+ _build_content_grid_from_figure_extracted_table_cell_annotations (
263+ cell_annotations
264+ )
265+ )
266+ tables [table_uid ] = (table_uid_to_type_mapping [table_uid ], content_grid )
189267 return tables
190268
191269
192270def extract_pd_dfs_from_output (
193271 serialized_document : dict [str , Any ],
194272 duplicate_merged_cells_content_flag : bool = True ,
195273 use_first_row_as_header : bool = True ,
274+ include_figure_extracted_table : bool = False ,
196275) -> list [pd .DataFrame ]:
197276 """Extract Extract output's tables and convert them to a list of pandas DataFrames.
198277
@@ -214,15 +293,22 @@ def extract_pd_dfs_from_output(
214293 2 2022 102,004 202,004 302,004 402,004
215294 3 2023 103,009 203,009 303,009 403,009]
216295 """
217- table_grids = build_table_grids (
296+ table_types_and_grids = build_table_grids (
218297 serialized_document , duplicate_merged_cells_content_flag
219298 )
220299 table_dfs = []
221- for table_grid in table_grids .values ():
222- table_df = convert_table_to_pd_df (
223- table_grid , use_first_row_as_header = use_first_row_as_header
224- )
225- table_dfs .append (table_df )
300+ for table_type_and_grid in table_types_and_grids .values ():
301+ if table_type_and_grid [0 ] in (
302+ ContentCategory .TABLE .value ,
303+ ContentCategory .TABLE_OF_CONTENTS .value ,
304+ ) or (
305+ include_figure_extracted_table
306+ and table_type_and_grid [0 ] == ContentCategory .FIGURE_EXTRACTED_TABLE .value
307+ ):
308+ table_df = convert_table_to_pd_df (
309+ table_type_and_grid [1 ], use_first_row_as_header = use_first_row_as_header
310+ )
311+ table_dfs .append (table_df )
226312
227313 return table_dfs
228314
@@ -231,6 +317,7 @@ def extract_pd_dfs_with_locs_from_output(
231317 serialized_document : dict [str , Any ],
232318 duplicate_merged_cells_content_flag : bool = True ,
233319 use_first_row_as_header : bool = True ,
320+ include_figure_extracted_table : bool = False ,
234321) -> list [Table ]:
235322 """Extract tables from output and convert them to a list of pd DataFrames and table locations.
236323
@@ -258,7 +345,7 @@ def extract_pd_dfs_with_locs_from_output(
258345 )]
259346 """
260347 # Get dfs
261- table_grids = build_table_grids (
348+ table_types_and_grids = build_table_grids (
262349 serialized_document , duplicate_merged_cells_content_flag
263350 )
264351
@@ -270,11 +357,22 @@ def extract_pd_dfs_with_locs_from_output(
270357
271358 # Match dfs and locations
272359 tables : list [Table ] = []
273- for table_uid , table_grid in table_grids .items ():
274- table_df = convert_table_to_pd_df (
275- table_grid , use_first_row_as_header = use_first_row_as_header
276- )
277- tables .append (
278- Table (df = table_df , locations = table_uid_to_locs_mapping [table_uid ])
279- )
360+ for table_uid , table_type_and_grid in table_types_and_grids .items ():
361+ if table_type_and_grid [0 ] in (
362+ ContentCategory .TABLE .value ,
363+ ContentCategory .TABLE_OF_CONTENTS .value ,
364+ ) or (
365+ include_figure_extracted_table
366+ and table_type_and_grid [0 ] == ContentCategory .FIGURE_EXTRACTED_TABLE .value
367+ ):
368+ table_df = convert_table_to_pd_df (
369+ table_type_and_grid [1 ], use_first_row_as_header = use_first_row_as_header
370+ )
371+ tables .append (
372+ Table (
373+ df = table_df ,
374+ table_type = table_type_and_grid [0 ],
375+ locations = table_uid_to_locs_mapping [table_uid ],
376+ )
377+ )
280378 return tables
0 commit comments