66from PIL import Image
77
88from ftfy import fix_text
9- from surya .recognition import RecognitionPredictor , OCRResult , TextLine
9+ from surya .detection import DetectionPredictor , TextDetectionResult
10+ from surya .recognition import RecognitionPredictor , TextLine
1011from surya .table_rec import TableRecPredictor
1112from surya .table_rec .schema import TableResult , TableCell as SuryaTableCell
1213from pdftext .extraction import table_output
@@ -35,6 +36,11 @@ class TableProcessor(BaseProcessor):
3536 "The batch size to use for the table recognition model." ,
3637 "Default is None, which will use the default batch size for the model." ,
3738 ] = None
39+ detection_batch_size : Annotated [
40+ int ,
41+ "The batch size to use for the table detection model." ,
42+ "Default is None, which will use the default batch size for the model." ,
43+ ] = None
3844 recognition_batch_size : Annotated [
3945 int ,
4046 "The batch size to use for the table recognition model." ,
@@ -56,27 +62,34 @@ class TableProcessor(BaseProcessor):
5662 bool ,
5763 "Whether to disable the tqdm progress bar." ,
5864 ] = False
59- drop_repeated_table_text : Annotated [bool , "Drop repeated text in OCR results." ] = False
65+ drop_repeated_table_text : Annotated [bool , "Drop repeated text in OCR results." ] = (
66+ False
67+ )
6068 filter_tag_list = ["p" , "table" , "td" , "tr" , "th" , "tbody" ]
6169 disable_ocr_math : Annotated [bool , "Disable inline math recognition in OCR" ] = False
70+ disable_ocr : Annotated [bool , "Disable OCR entirely." ] = False
6271
6372 def __init__ (
6473 self ,
6574 recognition_model : RecognitionPredictor ,
6675 table_rec_model : TableRecPredictor ,
76+ detection_model : DetectionPredictor ,
6777 config = None ,
6878 ):
6979 super ().__init__ (config )
7080
7181 self .recognition_model = recognition_model
7282 self .table_rec_model = table_rec_model
83+ self .detection_model = detection_model
7384
7485 def __call__ (self , document : Document ):
7586 filepath = document .filepath # Path to original pdf file
7687
7788 table_data = []
7889 for page in document .pages :
7990 for block in page .contained_blocks (document , self .block_types ):
91+ if block .block_type == BlockTypes .Table :
92+ block .polygon = block .polygon .expand (0.01 , 0.01 )
8093 image = block .get_image (document , highres = True )
8194 image_poly = block .polygon .rescale (
8295 (page .polygon .width , page .polygon .height ),
@@ -105,6 +118,9 @@ def __call__(self, document: Document):
105118 [t ["table_image" ] for t in table_data ],
106119 batch_size = self .get_table_rec_batch_size (),
107120 )
121+ assert len (tables ) == len (table_data ), (
122+ "Number of table results should match the number of tables"
123+ )
108124
109125 # Assign cell text if we don't need OCR
110126 # We do this at a line level
@@ -180,21 +196,21 @@ def finalize_cell_text(self, cell: SuryaTableCell):
180196 # Unspaced sequences: "...", "---", "___", "……"
181197 text = re .sub (r"[.\-_…]{2,}" , "" , text )
182198 # Remove mathbf formatting if there is only digits with decimals/commas/currency symbols inside
183- text = re .sub (r' \\mathbf\{([0-9.,$€£]+)\}' , r' <b>\1</b>' , text )
199+ text = re .sub (r" \\mathbf\{([0-9.,$€£]+)\}" , r" <b>\1</b>" , text )
184200 # Drop empty tags like \overline{}
185- text = re .sub (r' \\[a-zA-Z]+\{\s*\}' , '' , text )
201+ text = re .sub (r" \\[a-zA-Z]+\{\s*\}" , "" , text )
186202 # Drop \phantom{...} (remove contents too)
187- text = re .sub (r' \\phantom\{.*?\}' , '' , text )
203+ text = re .sub (r" \\phantom\{.*?\}" , "" , text )
188204 # Drop \quad
189- text = re .sub (r' \\quad' , '' , text )
205+ text = re .sub (r" \\quad" , "" , text )
190206 # Drop \,
191- text = re .sub (r' \\,' , '' , text )
207+ text = re .sub (r" \\," , "" , text )
192208 # Unwrap \mathsf{...}
193- text = re .sub (r' \\mathsf\{([^}]*)\}' , r'\1' , text )
209+ text = re .sub (r" \\mathsf\{([^}]*)\}" , r"\1" , text )
194210 # Handle unclosed tags: keep contents, drop the command
195- text = re .sub (r' \\[a-zA-Z]+\{([^}]*)$' , r'\1' , text )
211+ text = re .sub (r" \\[a-zA-Z]+\{([^}]*)$" , r"\1" , text )
196212 # If the whole string is \text{...} → unwrap
197- text = re .sub (r' ^\s*\\text\{([^}]*)\}\s*$' , r'\1' , text )
213+ text = re .sub (r" ^\s*\\text\{([^}]*)\}\s*$" , r"\1" , text )
198214
199215 # In case the above steps left no more latex math - We can unwrap
200216 text = unwrap_math (text )
@@ -479,31 +495,134 @@ def assign_pdftext_lines(self, extract_blocks: list, filepath: str):
479495 "Number of tables and table inputs must match"
480496 )
481497
482- def needs_ocr (self , tables : List [TableResult ]):
498+ def align_table_cells (
499+ self , table : TableResult , table_detection_result : TextDetectionResult
500+ ):
501+ table_cells = table .cells
502+ table_text_lines = table_detection_result .bboxes
503+
504+ text_line_bboxes = [t .bbox for t in table_text_lines ]
505+ table_cell_bboxes = [c .bbox for c in table_cells ]
506+
507+ intersection_matrix = matrix_intersection_area (
508+ text_line_bboxes , table_cell_bboxes
509+ )
510+
511+ # Map cells -> list of assigned text lines
512+ cell_text = defaultdict (list )
513+ for text_line_idx , table_text_line in enumerate (table_text_lines ):
514+ intersections = intersection_matrix [text_line_idx ]
515+ if intersections .sum () == 0 :
516+ continue
517+ max_intersection = intersections .argmax ()
518+ cell_text [max_intersection ].append (table_text_line )
519+
520+ # Adjust cell polygons in place
521+ for cell_idx , cell in enumerate (table_cells ):
522+ # all intersecting lines
523+ intersecting_line_indices = [
524+ i for i , area in enumerate (intersection_matrix [:, cell_idx ]) if area > 0
525+ ]
526+ if not intersecting_line_indices :
527+ continue
528+
529+ assigned_lines = cell_text .get (cell_idx , [])
530+ # Expand to fit assigned lines - **Only in the y direction**
531+ for assigned_line in assigned_lines :
532+ x1 = cell .bbox [0 ]
533+ x2 = cell .bbox [2 ]
534+ y1 = min (cell .bbox [1 ], assigned_line .bbox [1 ])
535+ y2 = max (cell .bbox [3 ], assigned_line .bbox [3 ])
536+ cell .polygon = [[x1 , y1 ], [x2 , y1 ], [x2 , y2 ], [x1 , y2 ]]
537+
538+ # Clear out non-assigned lines
539+ non_assigned_lines = [
540+ table_text_lines [i ]
541+ for i in intersecting_line_indices
542+ if table_text_lines [i ] not in cell_text .get (cell_idx , [])
543+ ]
544+ if non_assigned_lines :
545+ # Find top-most and bottom-most non-assigned boxes
546+ top_box = min (
547+ non_assigned_lines , key = lambda line : line .bbox [1 ]
548+ ) # smallest y0
549+ bottom_box = max (
550+ non_assigned_lines , key = lambda line : line .bbox [3 ]
551+ ) # largest y1
552+
553+ # Current cell bbox (from polygon)
554+ x0 , y0 , x1 , y1 = cell .bbox
555+
556+ # Adjust y-limits based on non-assigned boxes
557+ new_y0 = max (y0 , top_box .bbox [3 ]) # top moves down
558+ new_y1 = min (y1 , bottom_box .bbox [1 ]) # bottom moves up
559+
560+ if new_y0 < new_y1 :
561+ # Replace polygon with a new shrunken rectangle
562+ cell .polygon = [
563+ [x0 , new_y0 ],
564+ [x1 , new_y0 ],
565+ [x1 , new_y1 ],
566+ [x0 , new_y1 ],
567+ ]
568+
569+ def needs_ocr (self , tables : List [TableResult ], table_blocks : List [dict ]):
483570 ocr_tables = []
484- ocr_polys = []
485571 ocr_idxs = []
486- for j , result in enumerate (tables ):
487- table_cells : List [SuryaTableCell ] = result .cells
488- if any ([tc .text_lines is None for tc in table_cells ]):
489- ocr_tables .append (result )
490- polys = [tc for tc in table_cells if tc .text_lines is None ]
491- ocr_polys .append (polys )
572+ for j , (table_result , table_block ) in enumerate (zip (tables , table_blocks )):
573+ table_cells : List [SuryaTableCell ] = table_result .cells
574+ text_lines_need_ocr = any ([tc .text_lines is None for tc in table_cells ])
575+ if (
576+ table_block ["ocr_block" ]
577+ and text_lines_need_ocr
578+ and not self .disable_ocr
579+ ):
580+ logger .debug (
581+ f"Table { j } needs OCR, info table block needs ocr: { table_block ['ocr_block' ]} , text_lines { text_lines_need_ocr } "
582+ )
583+ ocr_tables .append (table_result )
492584 ocr_idxs .append (j )
585+
586+ detection_results : List [TextDetectionResult ] = self .detection_model (
587+ images = [table_blocks [i ]["table_image" ] for i in ocr_idxs ],
588+ batch_size = self .get_detection_batch_size (),
589+ )
590+ assert len (detection_results ) == len (ocr_idxs ), (
591+ "Every OCRed table requires a text detection result"
592+ )
593+
594+ for idx , table_detection_result in zip (ocr_idxs , detection_results ):
595+ self .align_table_cells (tables [idx ], table_detection_result )
596+
597+ ocr_polys = []
598+ for ocr_idx in ocr_idxs :
599+ table_cells = tables [ocr_idx ].cells
600+ polys = [tc for tc in table_cells if tc .text_lines is None ]
601+ ocr_polys .append (polys )
493602 return ocr_tables , ocr_polys , ocr_idxs
494603
495- def get_ocr_results (self , table_images : List [Image .Image ], ocr_polys : List [List [SuryaTableCell ]]):
496- ocr_polys_blank = []
604+ def get_ocr_results (
605+ self , table_images : List [Image .Image ], ocr_polys : List [List [SuryaTableCell ]]
606+ ):
607+ ocr_polys_bad = []
497608
498609 for table_image , polys in zip (table_images , ocr_polys ):
499- table_polys_blank = [is_blank_image (table_image .crop (poly .bbox ), poly .polygon ) for poly in polys ]
500- ocr_polys_blank .append (table_polys_blank )
501-
610+ table_polys_bad = [
611+ any (
612+ [
613+ poly .height < 6 ,
614+ is_blank_image (table_image .crop (poly .bbox ), poly .polygon ),
615+ ]
616+ )
617+ for poly in polys
618+ ]
619+ ocr_polys_bad .append (table_polys_bad )
620+
502621 filtered_polys = []
503- for table_polys , table_polys_blank in zip (ocr_polys , ocr_polys_blank ):
622+ for table_polys , table_polys_bad in zip (ocr_polys , ocr_polys_bad ):
504623 filtered_table_polys = []
505- for p , is_blank in zip (table_polys , table_polys_blank ):
506- if is_blank :
624+ for p , is_bad in zip (table_polys , table_polys_bad ):
625+ if is_bad :
507626 continue
508627 polygon = p .polygon
509628 # Round the polygon
@@ -527,19 +646,21 @@ def get_ocr_results(self, table_images: List[Image.Image], ocr_polys: List[List[
527646 )
528647
529648 # Re-align the predictions to the original length, since we skipped some predictions
530- for table_ocr_result , table_polys_blank in zip (ocr_results , ocr_polys_blank ):
649+ for table_ocr_result , table_polys_bad in zip (ocr_results , ocr_polys_bad ):
531650 updated_lines = []
532651 idx = 0
533- for is_blank in table_polys_blank :
534- if is_blank :
535- updated_lines .append (TextLine (
536- text = "" ,
537- polygon = [[0 , 0 ], [0 , 0 ], [0 , 0 ], [0 , 0 ]],
538- confidence = 1 ,
539- chars = [],
540- original_text_good = False ,
541- words = None
542- ))
652+ for is_bad in table_polys_bad :
653+ if is_bad :
654+ updated_lines .append (
655+ TextLine (
656+ text = "" ,
657+ polygon = [[0 , 0 ], [0 , 0 ], [0 , 0 ], [0 , 0 ]],
658+ confidence = 1 ,
659+ chars = [],
660+ original_text_good = False ,
661+ words = None ,
662+ )
663+ )
543664 else :
544665 updated_lines .append (table_ocr_result .text_lines [idx ])
545666 idx += 1
@@ -548,7 +669,7 @@ def get_ocr_results(self, table_images: List[Image.Image], ocr_polys: List[List[
548669 return ocr_results
549670
550671 def assign_ocr_lines (self , tables : List [TableResult ], table_blocks : list ):
551- ocr_tables , ocr_polys , ocr_idxs = self .needs_ocr (tables )
672+ ocr_tables , ocr_polys , ocr_idxs = self .needs_ocr (tables , table_blocks )
552673 det_images = [
553674 t ["table_image" ] for i , t in enumerate (table_blocks ) if i in ocr_idxs
554675 ]
@@ -589,3 +710,10 @@ def get_recognition_batch_size(self):
589710 elif settings .TORCH_DEVICE_MODEL == "cuda" :
590711 return 48
591712 return 32
713+
714+ def get_detection_batch_size (self ):
715+ if self .detection_batch_size is not None :
716+ return self .detection_batch_size
717+ elif settings .TORCH_DEVICE_MODEL == "cuda" :
718+ return 10
719+ return 4
0 commit comments