Skip to content

Commit 4a020d6

Browse files
Merge pull request #19 from kensho-technologies/libin/add-figex-table
[Kenverters] Adding support of FigEx table
2 parents 3834f7f + 0b20f53 commit 4a020d6

File tree

9 files changed

+1807
-48
lines changed

9 files changed

+1807
-48
lines changed

kensho_kenverters/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
## v1.2.3
44

5-
Remove table validation for rows and columns to not fail downstream of Extract model failures
5+
* Remove table validation for rows and columns to not fail downstream of Extract model failures
6+
7+
* Added support for figure-extracted tables.
68

79
## v1.2.2
810

kensho_kenverters/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class AnnotationType(Enum):
1717
"""Enum for the annotation type from the Extract output."""
1818

1919
TABLE_STRUCTURE = "table_structure"
20+
FIGURE_EXTRACTED_TABLE_STRUCTURE = "figure_extracted_table_structure"
2021

2122

2223
class ContentCategory(Enum):
@@ -51,6 +52,8 @@ class ContentCategory(Enum):
5152
PAGE_FOOTNOTE = "PAGE_FOOTNOTE"
5253
TABLE_OF_CONTENTS = "TABLE_OF_CONTENTS"
5354
TABLE_OF_CONTENTS_TITLE = "TABLE_OF_CONTENTS_TITLE"
55+
FIGURE_EXTRACTED_TABLE = "FIGURE_EXTRACTED_TABLE"
56+
FIGURE_EXTRACTED_TABLE_CELL = "FIGURE_EXTRACTED_TABLE_CELL"
5457

5558

5659
ELEMENT_TITLE_CONTENT_CATEGORIES = {
@@ -73,4 +76,5 @@ class ContentCategory(Enum):
7376
TABLE_CONTENT_CATEGORIES = {
7477
ContentCategory.TABLE.value,
7578
ContentCategory.TABLE_OF_CONTENTS.value,
79+
ContentCategory.FIGURE_EXTRACTED_TABLE.value,
7680
}

kensho_kenverters/convert_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def convert_output_to_markdown(serialized_document: dict[str, Any]) -> str:
290290

291291

292292
def convert_output_to_markdown_by_page(
293-
serialized_document: dict[str, Any]
293+
serialized_document: dict[str, Any],
294294
) -> list[str]:
295295
r"""Convert entire Extract output into a markdown string per page.
296296

kensho_kenverters/convert_output_visual_formatted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _get_segments_from_table_cells(
5050

5151

5252
def _convert_output_to_texts_with_locs(
53-
serialized_document: dict[str, Any]
53+
serialized_document: dict[str, Any],
5454
) -> list[dict[str, Any]]:
5555
"""Convert Extract output into a list of items.
5656

kensho_kenverters/extract_output_models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
# Copyright 2024-present Kensho Technologies, LLC.
22
"""Pydantic models for the output JSON."""
33

4-
from typing import NamedTuple, TypeAlias
4+
from typing import Literal, NamedTuple, TypeAlias
55

66
import pandas as pd
77
from pydantic import BaseModel # pylint: disable=no-name-in-module
88

99
# Location types are either dictionaries of bbox coordinates and page numbers
1010
# or None if locations are not returned in the Extract output.
1111
LocationType: TypeAlias = dict[str, float | int] | None
12+
TableCategoryType: TypeAlias = Literal[
13+
"TABLE",
14+
"TABLE_OF_CONTENTS",
15+
"FIGURE_EXTRACTED_TABLE",
16+
]
1217

1318

1419
class Table(NamedTuple):
    """A converted table: its pandas DataFrame, its table category and its location(s)."""

    # Table contents as a pandas DataFrame.
    df: pd.DataFrame
    # Category of the table; one of the TableCategoryType literals.
    table_type: TableCategoryType
    # Location entries for the table; defaults to None when not supplied.
    locations: list[LocationType] | None = None
1925

2026

@@ -33,6 +39,7 @@ class AnnotationDataModel(BaseModel):
3339

3440
index: tuple[int, int]
3541
span: tuple[int, int]
42+
value: str | None = None
3643

3744

3845
class AnnotationModel(BaseModel):

kensho_kenverters/output_to_sections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
def extract_organized_sections(
11-
serialized_document: dict[str, Any]
11+
serialized_document: dict[str, Any],
1212
) -> list[list[dict[str, Any]]]:
1313
r"""Return a version of the output organized into sections split on titles.
1414

kensho_kenverters/output_to_tables.py

Lines changed: 131 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2024-present Kensho Technologies, LLC.
22
"""Functions to extract the tables in the output and turn them into pandas DataFrames."""
33

4+
import typing
45
from collections import defaultdict
56
from typing import Any, Sequence
67

@@ -17,6 +18,7 @@
1718
LocationModel,
1819
LocationType,
1920
Table,
21+
TableCategoryType,
2022
)
2123
from kensho_kenverters.tables_utils import (
2224
convert_table_to_pd_df,
@@ -36,7 +38,11 @@ def _get_table_uid_to_cells_mapping(
3638
cells = [
3739
child
3840
for child in content.children
39-
if child.type == ContentCategory.TABLE_CELL.value
41+
if child.type
42+
in (
43+
ContentCategory.TABLE_CELL.value,
44+
ContentCategory.FIGURE_EXTRACTED_TABLE_CELL.value,
45+
)
4046
]
4147
current_mapping[content.uid] = cells
4248
elif len(content.children) > 0:
@@ -47,6 +53,22 @@ def _get_table_uid_to_cells_mapping(
4753
return current_mapping
4854

4955

56+
def _get_table_uid_to_types_mapping(
    content: ContentModel,
) -> dict[str, TableCategoryType]:
    """Map the uid of every table found under ``content`` to its table type.

    The content tree is walked depth-first, parents before children and children in
    their listed order. A node whose type is one of TABLE_CONTENT_CATEGORIES is
    recorded and its children are not visited.
    """
    uid_to_type: dict[str, TableCategoryType] = {}
    # Explicit stack instead of recursion; children are pushed reversed so that
    # they are popped (and therefore recorded) in their original order.
    pending = [content]
    while pending:
        node = pending.pop()
        if node.type in TABLE_CONTENT_CATEGORIES:
            # A table node terminates the descent on this branch.
            uid_to_type[node.uid] = typing.cast(TableCategoryType, node.type)
            continue
        if len(node.children) > 0:
            pending.extend(reversed(node.children))
    return uid_to_type
70+
71+
5072
def _get_table_uid_to_locations_mapping(
5173
content: ContentModel,
5274
) -> dict[str, list[LocationType]]:
@@ -83,31 +105,30 @@ def _get_table_uid_to_annotations_mapping(
83105
return table_to_annotations
84106

85107

86-
def _build_grid_from_table_cell_annotations(
87-
annotations: Sequence[AnnotationModel], duplicate_content_flag: bool = False
108+
def _build_uids_grid_from_table_cell_annotations(
109+
annotations: Sequence[AnnotationModel],
110+
duplicate_content_flag: bool = False,
88111
) -> list[list[list[str]]]:
89112
"""Build grid where each location has a list of content uids."""
90113
if any(
91114
annotation.type != AnnotationType.TABLE_STRUCTURE.value
92115
for annotation in annotations
93116
):
94117
raise ValueError(
95-
"Table grid can only be built from table structure annotations."
118+
"Content uids grid can only be built from table structure annotations."
96119
)
97-
98120
duplicated_annotations = duplicate_spanning_annotations(
99121
annotations, duplicate_content_flag
100122
)
123+
101124
index_to_uids_mapping = defaultdict(
102125
list,
103126
{
104127
annotation.data.index: annotation.content_uids
105128
for annotation in duplicated_annotations
106129
},
107130
)
108-
109131
n_rows, n_cols = get_table_shape(duplicated_annotations)
110-
111132
rows: list[list[list[str]]] = []
112133
for row_index in range(n_rows):
113134
current_row = []
@@ -117,11 +138,51 @@ def _build_grid_from_table_cell_annotations(
117138
return rows
118139

119140

141+
def _build_content_grid_from_figure_extracted_table_cell_annotations(
    annotations: Sequence[AnnotationModel],
) -> list[list[str]]:
    """Build a content grid where each location holds a string of content.

    Figure extracted table structure annotations carry their text directly in
    ``annotation.data.value``, so the grid is filled from the annotations themselves
    instead of being resolved through content uids.

    Args:
        annotations: figure extracted table structure annotations of one table

    Returns:
        a list of rows, each row being a list of cell strings. Positions inside the
        table shape that have no annotation of their own hold an empty string.

    Raises:
        ValueError: if any annotation is not a figure extracted table structure
            annotation, or if any annotation has no data value
    """
    if any(
        annotation.type != AnnotationType.FIGURE_EXTRACTED_TABLE_STRUCTURE.value
        for annotation in annotations
    ):
        raise ValueError(
            "Content grid can only be built from figure extracted table structure annotations."
        )

    # Validate and collect the values in a single pass. A missing value is an error,
    # so no fallback branch for None is needed afterwards.
    index_to_value: dict[tuple[int, int], str] = {}
    for annotation in annotations:
        value = annotation.data.value
        if value is None:
            raise ValueError(
                "Data value of figure extracted table structure "
                "annotations cannot be None."
            )
        index_to_value[annotation.data.index] = value

    n_rows, n_cols = get_table_shape(annotations)
    # A position within the shape may have no annotation indexed at it (presumably
    # when a cell spans several rows/columns -- confirm against get_table_shape).
    # Fill such gaps with "" rather than failing with a bare KeyError.
    return [
        [index_to_value.get((row_index, col_index), "") for col_index in range(n_cols)]
        for row_index in range(n_rows)
    ]
178+
179+
120180
def _convert_uid_grid_to_content_grid(
121181
uid_grid: list[list[list[str]]], cell_contents: Sequence[ContentModel]
122182
) -> list[list[str]]:
123183
"""Convert a UID grid to content grid."""
124184
uids_to_content = {cell.uid: cell.content for cell in cell_contents}
185+
125186
content_grid = []
126187
for uid_row in uid_grid:
127188
content_row = []
@@ -145,8 +206,8 @@ def _convert_uid_grid_to_content_grid(
145206
def build_table_grids(
146207
serialized_document: dict[str, Any],
147208
duplicate_merged_cells_content_flag: bool = True,
148-
) -> dict[str, list[list[str]]]:
149-
"""Convert serialized tables to a 2D grid of strings.
209+
) -> dict[str, tuple[TableCategoryType, list[list[str]]]]:
210+
"""Convert serialized tables to a table type and a 2D grid of strings.
150211
151212
Args:
152213
serialized_document: a serialized document
@@ -155,44 +216,62 @@ def build_table_grids(
155216
empty.
156217
157218
Returns:
158-
a mapping of table UIDs to table grid structures
219+
a mapping of table UIDs to the tuple of table type and table grid structures
159220
160221
Example Output:
161222
{
162-
'1': [['header1', 'header2'], ['row1_val', 'row2_val']],
163-
'2': [['another_header1'], ['another_row1_val']]
223+
'1': ("TABLE",[['header1', 'header2'], ['row1_val', 'row2_val']]),
224+
'2': ("FIGURE_EXTRACTED_TABLE",[['another_header1'], ['another_row1_val']])
164225
}
165226
"""
166227
parsed_serialized_document = load_output_to_pydantic(serialized_document)
167228
annotations = parsed_serialized_document.annotations
168229
content = parsed_serialized_document.content_tree
169230

170231
table_uid_to_cells_mapping = _get_table_uid_to_cells_mapping(content)
232+
table_uid_to_type_mapping = _get_table_uid_to_types_mapping(content)
171233

172234
table_cell_annotations = [
173235
annotation
174236
for annotation in annotations
175-
if annotation.type == AnnotationType.TABLE_STRUCTURE.value
237+
if annotation.type
238+
in (
239+
AnnotationType.TABLE_STRUCTURE.value,
240+
AnnotationType.FIGURE_EXTRACTED_TABLE_STRUCTURE.value,
241+
)
176242
]
177243
table_uid_to_cell_annotations = _get_table_uid_to_annotations_mapping(
178244
table_uid_to_cells_mapping, table_cell_annotations
179245
)
180246

181247
tables = {}
182248
for table_uid, cell_annotations in table_uid_to_cell_annotations.items():
183-
grid = _build_grid_from_table_cell_annotations(
184-
cell_annotations, duplicate_content_flag=duplicate_merged_cells_content_flag
185-
)
186-
cell_contents = table_uid_to_cells_mapping[table_uid]
187-
content_grid = _convert_uid_grid_to_content_grid(grid, cell_contents)
188-
tables[table_uid] = content_grid
249+
if table_uid_to_type_mapping[table_uid] in (
250+
ContentCategory.TABLE.value,
251+
ContentCategory.TABLE_OF_CONTENTS.value,
252+
):
253+
uids_grid = _build_uids_grid_from_table_cell_annotations(
254+
cell_annotations,
255+
duplicate_content_flag=duplicate_merged_cells_content_flag,
256+
)
257+
cell_contents = table_uid_to_cells_mapping[table_uid]
258+
content_grid = _convert_uid_grid_to_content_grid(uids_grid, cell_contents)
259+
tables[table_uid] = (table_uid_to_type_mapping[table_uid], content_grid)
260+
else:
261+
content_grid = (
262+
_build_content_grid_from_figure_extracted_table_cell_annotations(
263+
cell_annotations
264+
)
265+
)
266+
tables[table_uid] = (table_uid_to_type_mapping[table_uid], content_grid)
189267
return tables
190268

191269

192270
def extract_pd_dfs_from_output(
193271
serialized_document: dict[str, Any],
194272
duplicate_merged_cells_content_flag: bool = True,
195273
use_first_row_as_header: bool = True,
274+
include_figure_extracted_table: bool = False,
196275
) -> list[pd.DataFrame]:
197276
"""Extract Extract output's tables and convert them to a list of pandas DataFrames.
198277
@@ -214,15 +293,22 @@ def extract_pd_dfs_from_output(
214293
2 2022 102,004 202,004 302,004 402,004
215294
3 2023 103,009 203,009 303,009 403,009]
216295
"""
217-
table_grids = build_table_grids(
296+
table_types_and_grids = build_table_grids(
218297
serialized_document, duplicate_merged_cells_content_flag
219298
)
220299
table_dfs = []
221-
for table_grid in table_grids.values():
222-
table_df = convert_table_to_pd_df(
223-
table_grid, use_first_row_as_header=use_first_row_as_header
224-
)
225-
table_dfs.append(table_df)
300+
for table_type_and_grid in table_types_and_grids.values():
301+
if table_type_and_grid[0] in (
302+
ContentCategory.TABLE.value,
303+
ContentCategory.TABLE_OF_CONTENTS.value,
304+
) or (
305+
include_figure_extracted_table
306+
and table_type_and_grid[0] == ContentCategory.FIGURE_EXTRACTED_TABLE.value
307+
):
308+
table_df = convert_table_to_pd_df(
309+
table_type_and_grid[1], use_first_row_as_header=use_first_row_as_header
310+
)
311+
table_dfs.append(table_df)
226312

227313
return table_dfs
228314

@@ -231,6 +317,7 @@ def extract_pd_dfs_with_locs_from_output(
231317
serialized_document: dict[str, Any],
232318
duplicate_merged_cells_content_flag: bool = True,
233319
use_first_row_as_header: bool = True,
320+
include_figure_extracted_table: bool = False,
234321
) -> list[Table]:
235322
"""Extract tables from output and convert them to a list of pd DataFrames and table locations.
236323
@@ -258,7 +345,7 @@ def extract_pd_dfs_with_locs_from_output(
258345
)]
259346
"""
260347
# Get dfs
261-
table_grids = build_table_grids(
348+
table_types_and_grids = build_table_grids(
262349
serialized_document, duplicate_merged_cells_content_flag
263350
)
264351

@@ -270,11 +357,22 @@ def extract_pd_dfs_with_locs_from_output(
270357

271358
# Match dfs and locations
272359
tables: list[Table] = []
273-
for table_uid, table_grid in table_grids.items():
274-
table_df = convert_table_to_pd_df(
275-
table_grid, use_first_row_as_header=use_first_row_as_header
276-
)
277-
tables.append(
278-
Table(df=table_df, locations=table_uid_to_locs_mapping[table_uid])
279-
)
360+
for table_uid, table_type_and_grid in table_types_and_grids.items():
361+
if table_type_and_grid[0] in (
362+
ContentCategory.TABLE.value,
363+
ContentCategory.TABLE_OF_CONTENTS.value,
364+
) or (
365+
include_figure_extracted_table
366+
and table_type_and_grid[0] == ContentCategory.FIGURE_EXTRACTED_TABLE.value
367+
):
368+
table_df = convert_table_to_pd_df(
369+
table_type_and_grid[1], use_first_row_as_header=use_first_row_as_header
370+
)
371+
tables.append(
372+
Table(
373+
df=table_df,
374+
table_type=table_type_and_grid[0],
375+
locations=table_uid_to_locs_mapping[table_uid],
376+
)
377+
)
280378
return tables

0 commit comments

Comments
 (0)