@@ -515,35 +515,27 @@ def multi_table_predict(
515515 indexing_start_cols = (
516516 []
517517 ) # Index of original start col IDs (not indexes)
518- indexing_end_cols = [] # Index of original end col IDs (not indexes)
519518 indexing_start_rows = (
520519 []
521520 ) # Index of original start row IDs (not indexes)
522- indexing_end_rows = [] # Index of original end row IDs (not indexes)
523521
524522 # First, collect all possible predicted IDs, to be used as indexes
525523 # ID's returned by Tableformer are sequential, but might contain gaps
526524 for tf_response_cell in tf_responses :
527525 start_col_offset_idx = tf_response_cell ["start_col_offset_idx" ]
528- end_col_offset_idx = tf_response_cell ["end_col_offset_idx" ]
529526 start_row_offset_idx = tf_response_cell ["start_row_offset_idx" ]
530- end_row_offset_idx = tf_response_cell ["end_row_offset_idx" ]
531527
532528 # Collect all possible col/row IDs:
533529 if start_col_offset_idx not in indexing_start_cols :
534530 indexing_start_cols .append (start_col_offset_idx )
535- if end_col_offset_idx not in indexing_end_cols :
536- indexing_end_cols .append (end_col_offset_idx )
537531 if start_row_offset_idx not in indexing_start_rows :
538532 indexing_start_rows .append (start_row_offset_idx )
539- if end_row_offset_idx not in indexing_end_rows :
540- indexing_end_rows .append (end_row_offset_idx )
541533
542534 indexing_start_cols .sort ()
543- indexing_end_cols .sort ()
544535 indexing_start_rows .sort ()
545- indexing_end_rows .sort ()
546536
537+ max_end_col_idx = 0
538+ max_end_row_idx = 0
547539 # After this - put actual indexes of IDs back into predicted structure...
548540 for tf_response_cell in tf_responses :
549541 tf_response_cell ["start_col_offset_idx" ] = (
@@ -555,6 +547,9 @@ def multi_table_predict(
555547 tf_response_cell ["start_col_offset_idx" ]
556548 + tf_response_cell ["col_span" ]
557549 )
550+ max_end_col_idx = max (
551+ max_end_col_idx , tf_response_cell ["end_col_offset_idx" ]
552+ )
558553 tf_response_cell ["start_row_offset_idx" ] = (
559554 indexing_start_rows .index (
560555 tf_response_cell ["start_row_offset_idx" ]
@@ -564,9 +559,12 @@ def multi_table_predict(
564559 tf_response_cell ["start_row_offset_idx" ]
565560 + tf_response_cell ["row_span" ]
566561 )
562+ max_end_row_idx = max (
563+ max_end_row_idx , tf_response_cell ["end_row_offset_idx" ]
564+ )
567565 # Counting matched cols/rows from actual indexes (and not ids)
568- predict_details ["num_cols" ] = len ( indexing_end_cols )
569- predict_details ["num_rows" ] = len ( indexing_end_rows )
566+ predict_details ["num_cols" ] = max_end_col_idx
567+ predict_details ["num_rows" ] = max_end_row_idx
570568 else :
571569 otsl_seq = predict_details ["prediction" ]["rs_seq" ]
572570 predict_details ["num_cols" ] = otsl_seq .index ("nl" )
0 commit comments