Skip to content

Commit a45d146

Browse files
authored
fix: improve calculation of num_cols and num_rows (#126)
Signed-off-by: Michele Dolfi <[email protected]>
1 parent 320b3ec commit a45d146

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

docling_ibm_models/tableformer/data_management/tf_predictor.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -515,35 +515,27 @@ def multi_table_predict(
515515
indexing_start_cols = (
516516
[]
517517
) # Index of original start col IDs (not indexes)
518-
indexing_end_cols = [] # Index of original end col IDs (not indexes)
519518
indexing_start_rows = (
520519
[]
521520
) # Index of original start row IDs (not indexes)
522-
indexing_end_rows = [] # Index of original end row IDs (not indexes)
523521

524522
# First, collect all possible predicted IDs, to be used as indexes
525523
# ID's returned by Tableformer are sequential, but might contain gaps
526524
for tf_response_cell in tf_responses:
527525
start_col_offset_idx = tf_response_cell["start_col_offset_idx"]
528-
end_col_offset_idx = tf_response_cell["end_col_offset_idx"]
529526
start_row_offset_idx = tf_response_cell["start_row_offset_idx"]
530-
end_row_offset_idx = tf_response_cell["end_row_offset_idx"]
531527

532528
# Collect all possible col/row IDs:
533529
if start_col_offset_idx not in indexing_start_cols:
534530
indexing_start_cols.append(start_col_offset_idx)
535-
if end_col_offset_idx not in indexing_end_cols:
536-
indexing_end_cols.append(end_col_offset_idx)
537531
if start_row_offset_idx not in indexing_start_rows:
538532
indexing_start_rows.append(start_row_offset_idx)
539-
if end_row_offset_idx not in indexing_end_rows:
540-
indexing_end_rows.append(end_row_offset_idx)
541533

542534
indexing_start_cols.sort()
543-
indexing_end_cols.sort()
544535
indexing_start_rows.sort()
545-
indexing_end_rows.sort()
546536

537+
max_end_col_idx = 0
538+
max_end_row_idx = 0
547539
# After this - put actual indexes of IDs back into predicted structure...
548540
for tf_response_cell in tf_responses:
549541
tf_response_cell["start_col_offset_idx"] = (
@@ -555,6 +547,9 @@ def multi_table_predict(
555547
tf_response_cell["start_col_offset_idx"]
556548
+ tf_response_cell["col_span"]
557549
)
550+
max_end_col_idx = max(
551+
max_end_col_idx, tf_response_cell["end_col_offset_idx"]
552+
)
558553
tf_response_cell["start_row_offset_idx"] = (
559554
indexing_start_rows.index(
560555
tf_response_cell["start_row_offset_idx"]
@@ -564,9 +559,12 @@ def multi_table_predict(
564559
tf_response_cell["start_row_offset_idx"]
565560
+ tf_response_cell["row_span"]
566561
)
562+
max_end_row_idx = max(
563+
max_end_row_idx, tf_response_cell["end_row_offset_idx"]
564+
)
567565
# Counting matched cols/rows from actual indexes (and not ids)
568-
predict_details["num_cols"] = len(indexing_end_cols)
569-
predict_details["num_rows"] = len(indexing_end_rows)
566+
predict_details["num_cols"] = max_end_col_idx
567+
predict_details["num_rows"] = max_end_row_idx
570568
else:
571569
otsl_seq = predict_details["prediction"]["rs_seq"]
572570
predict_details["num_cols"] = otsl_seq.index("nl")

0 commit comments

Comments
 (0)