99from kensho_kenverters .extract_output_models import AnnotationDataModel , AnnotationModel
1010
1111
12- def _check_complete_set (integer_set : set [int ]) -> bool :
13- """Check that the set of integers contains all integers between 0 and its max."""
14- return integer_set == set (range (max (integer_set , default = - 1 ) + 1 ))
def _create_empty_annotation(row: int, col: int) -> AnnotationModel:
    """Build a placeholder table-structure annotation for one cell.

    The annotation sits at index (row, col), has a span of (1, 1), an empty
    list of content uids and no locations.
    """
    cell_data = AnnotationDataModel(
        span=(1, 1),
        index=(row, col),
    )
    return AnnotationModel(
        type=AnnotationType.TABLE_STRUCTURE.value,
        content_uids=[],
        data=cell_data,
        locations=None,
    )
23+
24+
25+ def _validate_annotations (
26+ duplicated_annotations : list [AnnotationModel ], max_row : int , max_col : int
27+ ) -> list [AnnotationModel ]:
28+ """Validate duplicated annotations.
1529
30+ Fill with empty annotations if rows or columns are missing.
31+ """
1632
17- def _validate_annotations (duplicated_annotations : Sequence [AnnotationModel ]) -> None :
18- """Validate duplicated annotations."""
1933 # Check all spans are 1 (annotations are duplicated)
2034 all_spans = [annotation .data .span for annotation in duplicated_annotations ]
2135 if any (span != (1 , 1 ) for span in all_spans ):
@@ -26,13 +40,13 @@ def _validate_annotations(duplicated_annotations: Sequence[AnnotationModel]) ->
2640 if len (set (all_indices )) != len (all_indices ):
2741 raise ValueError ("Overlapping indices in table." )
2842
29- # Check no empty rows or columns
30- all_rows = set ( index [ 0 ] for index in all_indices )
31- all_columns = set ( index [ 1 ] for index in all_indices )
32- if not _check_complete_set ( all_rows ) :
33- raise ValueError ( "Empty row in table." )
34- if not _check_complete_set ( all_columns ):
35- raise ValueError ( "Empty column in table." )
43+ # Add any missing cells
44+ for row in range ( max_row + 1 ):
45+ for col in range ( max_col + 1 ):
46+ if ( row , col ) not in all_indices :
47+ duplicated_annotations . append ( _create_empty_annotation ( row , col ) )
48+
49+ return duplicated_annotations
3650
3751
3852def duplicate_spanning_annotations (
@@ -52,6 +66,8 @@ def duplicate_spanning_annotations(
5266 duplicated annotations. Duplicated annotations must all have span (1, 1).
5367 """
5468 duplicated_annotations = []
69+ max_row = 0
70+ max_col = 0
5571 for annotation in annotations :
5672 data = annotation .data
5773 row_span , col_span = data .span
@@ -76,9 +92,10 @@ def duplicate_spanning_annotations(
7692 locations = annotation .locations ,
7793 )
7894 duplicated_annotations .append (new_annotation )
95+ max_row = max (max_row , row_index + row_span_index )
96+ max_col = max (max_col , col_index + col_span_index )
7997
80- _validate_annotations (duplicated_annotations )
81- return duplicated_annotations
98+ return _validate_annotations (duplicated_annotations , max_row , max_col )
8299
83100
84101def get_table_shape (
0 commit comments