Skip to content

Commit 8e52707

Browse files
authored
Merge pull request #868 from datalab-to/dev
Dev
2 parents f8247e2 + 920fe56 commit 8e52707

File tree

9 files changed

+217
-67
lines changed

9 files changed

+217
-67
lines changed

.github/workflows/cla.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ jobs:
2929
path-to-document: 'https://github.com/VikParuchuri/marker/blob/master/CLA.md'
3030
# branch should not be protected
3131
branch: 'master'
32-
allowlist: VikParuchuri
32+
allowlist: VikParuchuri,Sandy

marker/processors/llm/llm_table.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ class LLMTableProcessor(BaseLLMComplexBlockProcessor):
3636
float,
3737
"The maximum width/height ratio for table cells for a table to be considered rotated.",
3838
] = 0.6
39+
max_table_iterations: Annotated[
40+
int,
41+
"The maximum number of iterations to attempt rewriting a table.",
42+
] = 2
3943
table_rewriting_prompt: Annotated[
4044
str,
4145
"The prompt to use for rewriting text.",
@@ -58,6 +62,7 @@ class LLMTableProcessor(BaseLLMComplexBlockProcessor):
5862
2. Analyze the html representation of the table.
5963
3. Write a comparison of the image and the html representation, paying special attention to the column headers matching the correct column values.
6064
4. If the html representation is completely correct, or you cannot read the image properly, then write "No corrections needed." If the html representation has errors, generate the corrected html representation. Output only either the corrected html representation or "No corrections needed."
65+
5. If you made corrections, analyze your corrections against the original image, and provide a score from 1-5, indicating how well the corrected html matches the image, with 5 being perfect.
6166
**Example:**
6267
Input:
6368
```html
@@ -70,7 +75,6 @@ class LLMTableProcessor(BaseLLMComplexBlockProcessor):
7075
<tr>
7176
<td>John</td>
7277
<td>Doe</td>
73-
<td>25</td>
7478
</tr>
7579
</table>
7680
```
@@ -79,6 +83,8 @@ class LLMTableProcessor(BaseLLMComplexBlockProcessor):
7983
```html
8084
No corrections needed.
8185
```
86+
analysis: I did not make any corrections, as the html representation was already accurate.
87+
score: 5
8288
**Input:**
8389
```html
8490
{block_html}
@@ -186,6 +192,7 @@ def rewrite_single_chunk(
186192
block_html: str,
187193
children: List[TableCell],
188194
image: Image.Image,
195+
total_iterations: int = 0,
189196
):
190197
prompt = self.table_rewriting_prompt.replace("{block_html}", block_html)
191198

@@ -202,19 +209,31 @@ def rewrite_single_chunk(
202209
return
203210

204211
corrected_html = corrected_html.strip().lstrip("```html").rstrip("```").strip()
212+
213+
# Re-iterate if low score
214+
total_iterations += 1
215+
score = response.get("score", 5)
216+
analysis = response.get("analysis", "")
217+
logger.debug(f"Got table rewriting score {score} with analysis: {analysis}")
218+
if total_iterations < self.max_table_iterations and score < 4:
219+
logger.info(
220+
f"Table rewriting low score {score}, on iteration {total_iterations}"
221+
)
222+
block_html = corrected_html
223+
return self.rewrite_single_chunk(
224+
page, block, block_html, children, image, total_iterations
225+
)
226+
205227
parsed_cells = self.parse_html_table(corrected_html, block, page)
206228
if len(parsed_cells) <= 1:
207229
block.update_metadata(llm_error_count=1)
230+
logger.debug(f"Table parsing issue, only {len(parsed_cells)} cells found")
208231
return
209232

210233
if not corrected_html.endswith("</table>"):
211-
block.update_metadata(llm_error_count=1)
212-
return
213-
214-
parsed_cell_text = "".join([cell.text for cell in parsed_cells])
215-
orig_cell_text = "".join([cell.text for cell in children])
216-
# Potentially a partial response
217-
if len(parsed_cell_text) < len(orig_cell_text) * 0.5:
234+
logger.debug(
235+
"Table parsing issue, corrected html does not end with </table>"
236+
)
218237
block.update_metadata(llm_error_count=1)
219238
return
220239

@@ -304,3 +323,5 @@ def parse_html_table(
304323
class TableSchema(BaseModel):
    # Structured fields read back from the table-rewriting LLM response.
    # Presumably passed as the response schema for the rewriting prompt;
    # confirm against the call site.
    #
    # NOTE: documented with comments rather than a docstring on purpose:
    # pydantic copies a model's docstring into the generated JSON schema
    # description, which would change the schema.

    # Free-text comparison of the table image against the HTML representation.
    comparison: str
    # Corrected table HTML, or the literal "No corrections needed." marker.
    corrected_html: str
    # Model's own assessment of its corrections against the original image.
    analysis: str
    # Self-reported quality score; the prompt asks for 1-5, with 5 being perfect.
    score: int

marker/processors/table.py

Lines changed: 166 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from PIL import Image
77

88
from ftfy import fix_text
9-
from surya.recognition import RecognitionPredictor, OCRResult, TextLine
9+
from surya.detection import DetectionPredictor, TextDetectionResult
10+
from surya.recognition import RecognitionPredictor, TextLine
1011
from surya.table_rec import TableRecPredictor
1112
from surya.table_rec.schema import TableResult, TableCell as SuryaTableCell
1213
from pdftext.extraction import table_output
@@ -35,6 +36,11 @@ class TableProcessor(BaseProcessor):
3536
"The batch size to use for the table recognition model.",
3637
"Default is None, which will use the default batch size for the model.",
3738
] = None
39+
detection_batch_size: Annotated[
40+
int,
41+
"The batch size to use for the table detection model.",
42+
"Default is None, which will use the default batch size for the model.",
43+
] = None
3844
recognition_batch_size: Annotated[
3945
int,
4046
"The batch size to use for the table recognition model.",
@@ -56,27 +62,34 @@ class TableProcessor(BaseProcessor):
5662
bool,
5763
"Whether to disable the tqdm progress bar.",
5864
] = False
59-
drop_repeated_table_text: Annotated[bool, "Drop repeated text in OCR results."] = False
65+
drop_repeated_table_text: Annotated[bool, "Drop repeated text in OCR results."] = (
66+
False
67+
)
6068
filter_tag_list = ["p", "table", "td", "tr", "th", "tbody"]
6169
disable_ocr_math: Annotated[bool, "Disable inline math recognition in OCR"] = False
70+
disable_ocr: Annotated[bool, "Disable OCR entirely."] = False
6271

6372
def __init__(
    self,
    recognition_model: RecognitionPredictor,
    table_rec_model: TableRecPredictor,
    detection_model: DetectionPredictor,
    config=None,
):
    """
    Initialize the processor and keep references to the injected predictors.

    Args:
        recognition_model: Stored as ``self.recognition_model``.
        table_rec_model: Stored as ``self.table_rec_model``.
        detection_model: Stored as ``self.detection_model``.
        config: Forwarded unchanged to the base-class initializer.
    """
    # Base-class setup runs first, before any attribute is attached.
    super().__init__(config)

    # The models are only referenced here; nothing is loaded or invoked.
    self.detection_model = detection_model
    self.table_rec_model = table_rec_model
    self.recognition_model = recognition_model
7384

7485
def __call__(self, document: Document):
7586
filepath = document.filepath # Path to original pdf file
7687

7788
table_data = []
7889
for page in document.pages:
7990
for block in page.contained_blocks(document, self.block_types):
91+
if block.block_type == BlockTypes.Table:
92+
block.polygon = block.polygon.expand(0.01, 0.01)
8093
image = block.get_image(document, highres=True)
8194
image_poly = block.polygon.rescale(
8295
(page.polygon.width, page.polygon.height),
@@ -105,6 +118,9 @@ def __call__(self, document: Document):
105118
[t["table_image"] for t in table_data],
106119
batch_size=self.get_table_rec_batch_size(),
107120
)
121+
assert len(tables) == len(table_data), (
122+
"Number of table results should match the number of tables"
123+
)
108124

109125
# Assign cell text if we don't need OCR
110126
# We do this at a line level
@@ -180,21 +196,21 @@ def finalize_cell_text(self, cell: SuryaTableCell):
180196
# Unspaced sequences: "...", "---", "___", "……"
181197
text = re.sub(r"[.\-_…]{2,}", "", text)
182198
# Remove mathbf formatting if there is only digits with decimals/commas/currency symbols inside
183-
text = re.sub(r'\\mathbf\{([0-9.,$€£]+)\}', r'<b>\1</b>', text)
199+
text = re.sub(r"\\mathbf\{([0-9.,$€£]+)\}", r"<b>\1</b>", text)
184200
# Drop empty tags like \overline{}
185-
text = re.sub(r'\\[a-zA-Z]+\{\s*\}', '', text)
201+
text = re.sub(r"\\[a-zA-Z]+\{\s*\}", "", text)
186202
# Drop \phantom{...} (remove contents too)
187-
text = re.sub(r'\\phantom\{.*?\}', '', text)
203+
text = re.sub(r"\\phantom\{.*?\}", "", text)
188204
# Drop \quad
189-
text = re.sub(r'\\quad', '', text)
205+
text = re.sub(r"\\quad", "", text)
190206
# Drop \,
191-
text = re.sub(r'\\,', '', text)
207+
text = re.sub(r"\\,", "", text)
192208
# Unwrap \mathsf{...}
193-
text = re.sub(r'\\mathsf\{([^}]*)\}', r'\1', text)
209+
text = re.sub(r"\\mathsf\{([^}]*)\}", r"\1", text)
194210
# Handle unclosed tags: keep contents, drop the command
195-
text = re.sub(r'\\[a-zA-Z]+\{([^}]*)$', r'\1', text)
211+
text = re.sub(r"\\[a-zA-Z]+\{([^}]*)$", r"\1", text)
196212
# If the whole string is \text{...} → unwrap
197-
text = re.sub(r'^\s*\\text\{([^}]*)\}\s*$', r'\1', text)
213+
text = re.sub(r"^\s*\\text\{([^}]*)\}\s*$", r"\1", text)
198214

199215
# In case the above steps left no more latex math - We can unwrap
200216
text = unwrap_math(text)
@@ -479,31 +495,134 @@ def assign_pdftext_lines(self, extract_blocks: list, filepath: str):
479495
"Number of tables and table inputs must match"
480496
)
481497

482-
def needs_ocr(self, tables: List[TableResult]):
498+
def align_table_cells(
    self, table: TableResult, table_detection_result: TextDetectionResult
):
    """
    Adjust the vertical extent of each table cell, in place, using detected text lines.

    Every detected text line that overlaps at least one cell is assigned to the
    cell it overlaps the most. Each cell that overlaps any line is then:

    1. Expanded in the y direction so it covers every line assigned to it
       (the x extent is never changed).
    2. Shrunk in the y direction to exclude lines that overlap it but were
       assigned to another cell, as long as the shrunken box stays non-empty.

    Args:
        table: Table recognition result whose ``cells`` have their ``polygon``
            rewritten.
        table_detection_result: Text detection result for the same table image;
            its ``bboxes`` are the detected text lines.

    Returns:
        None. The cells of ``table`` are modified in place.
    """
    table_cells = table.cells
    table_text_lines = table_detection_result.bboxes

    # Rows correspond to text lines, columns to table cells.
    intersection_matrix = matrix_intersection_area(
        [line.bbox for line in table_text_lines],
        [cell.bbox for cell in table_cells],
    )

    # Map cell index -> indices of the text lines assigned to that cell.
    # Assignments are tracked by line index (plain ints) so that membership
    # checks below are O(1) and never rely on comparing line objects.
    assigned_line_idxs = defaultdict(list)
    for line_idx, intersections in enumerate(intersection_matrix):
        if intersections.sum() == 0:
            # Line touches no cell at all
            continue
        assigned_line_idxs[int(intersections.argmax())].append(line_idx)

    # Adjust cell polygons in place
    for cell_idx, cell in enumerate(table_cells):
        # All lines that overlap this cell, assigned to it or not
        overlapping_line_idxs = [
            line_idx
            for line_idx, area in enumerate(intersection_matrix[:, cell_idx])
            if area > 0
        ]
        if not overlapping_line_idxs:
            continue

        own_line_idxs = assigned_line_idxs.get(cell_idx, [])

        # Expand to fit assigned lines - **Only in the y direction**
        for line_idx in own_line_idxs:
            assigned_line = table_text_lines[line_idx]
            x1 = cell.bbox[0]
            x2 = cell.bbox[2]
            y1 = min(cell.bbox[1], assigned_line.bbox[1])
            y2 = max(cell.bbox[3], assigned_line.bbox[3])
            cell.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

        # Lines that overlap this cell but belong to a different cell
        own_lookup = set(own_line_idxs)
        foreign_lines = [
            table_text_lines[line_idx]
            for line_idx in overlapping_line_idxs
            if line_idx not in own_lookup
        ]
        if not foreign_lines:
            continue

        # Top-most (smallest y0) and bottom-most (largest y1) foreign lines
        top_box = min(foreign_lines, key=lambda line: line.bbox[1])
        bottom_box = max(foreign_lines, key=lambda line: line.bbox[3])

        # Current cell bbox (read after the expansion above)
        x0, y0, x1, y1 = cell.bbox

        # Adjust y-limits based on the foreign lines
        new_y0 = max(y0, top_box.bbox[3])  # top moves down
        new_y1 = min(y1, bottom_box.bbox[1])  # bottom moves up

        # NOTE(review): if all foreign lines sit on the same side of the
        # cell's own text, this window is the gap between foreign lines and
        # may exclude the cell's own text - confirm this is intended.
        if new_y0 < new_y1:
            # Replace polygon with a new shrunken rectangle
            cell.polygon = [
                [x0, new_y0],
                [x1, new_y0],
                [x1, new_y1],
                [x0, new_y1],
            ]
568+
569+
def needs_ocr(self, tables: List[TableResult], table_blocks: List[dict]):
    """
    Select the tables that require OCR and collect the cells to recognize.

    A table is selected when its block is flagged with ``ocr_block``, at least
    one of its cells has ``text_lines`` set to ``None``, and OCR is not
    disabled via ``self.disable_ocr``. Text detection is run on the images of
    the selected tables and the result is used to re-align their cells (see
    ``align_table_cells``) before the cells needing OCR are collected.

    Args:
        tables: Table recognition results, parallel to ``table_blocks``.
        table_blocks: Per-table dicts; the ``"ocr_block"`` and
            ``"table_image"`` keys are read here.

    Returns:
        A tuple ``(ocr_tables, ocr_polys, ocr_idxs)``: the selected table
        results, for each selected table the cells whose ``text_lines`` is
        ``None``, and the indices of the selected tables within ``tables``.
        All three are empty lists when no table needs OCR.
    """
    ocr_tables = []
    ocr_idxs = []
    for j, (table_result, table_block) in enumerate(zip(tables, table_blocks)):
        text_lines_need_ocr = any(
            tc.text_lines is None for tc in table_result.cells
        )
        if (
            table_block["ocr_block"]
            and text_lines_need_ocr
            and not self.disable_ocr
        ):
            logger.debug(
                f"Table {j} needs OCR, info table block needs ocr: {table_block['ocr_block']}, text_lines {text_lines_need_ocr}"
            )
            ocr_tables.append(table_result)
            ocr_idxs.append(j)

    # Nothing to OCR: skip the detection model entirely instead of invoking
    # it on an empty batch.
    if not ocr_idxs:
        return ocr_tables, [], ocr_idxs

    detection_results: List[TextDetectionResult] = self.detection_model(
        images=[table_blocks[i]["table_image"] for i in ocr_idxs],
        batch_size=self.get_detection_batch_size(),
    )
    assert len(detection_results) == len(ocr_idxs), (
        "Every OCRed table requires a text detection result"
    )

    # Cell polygons are adjusted in place before they are handed to OCR
    for idx, table_detection_result in zip(ocr_idxs, detection_results):
        self.align_table_cells(tables[idx], table_detection_result)

    # Only cells without text lines are sent to recognition
    ocr_polys = [
        [tc for tc in tables[ocr_idx].cells if tc.text_lines is None]
        for ocr_idx in ocr_idxs
    ]
    return ocr_tables, ocr_polys, ocr_idxs
494603

495-
def get_ocr_results(self, table_images: List[Image.Image], ocr_polys: List[List[SuryaTableCell]]):
496-
ocr_polys_blank = []
604+
def get_ocr_results(
605+
self, table_images: List[Image.Image], ocr_polys: List[List[SuryaTableCell]]
606+
):
607+
ocr_polys_bad = []
497608

498609
for table_image, polys in zip(table_images, ocr_polys):
499-
table_polys_blank = [is_blank_image(table_image.crop(poly.bbox), poly.polygon) for poly in polys]
500-
ocr_polys_blank.append(table_polys_blank)
501-
610+
table_polys_bad = [
611+
any(
612+
[
613+
poly.height < 6,
614+
is_blank_image(table_image.crop(poly.bbox), poly.polygon),
615+
]
616+
)
617+
for poly in polys
618+
]
619+
ocr_polys_bad.append(table_polys_bad)
620+
502621
filtered_polys = []
503-
for table_polys, table_polys_blank in zip(ocr_polys, ocr_polys_blank):
622+
for table_polys, table_polys_bad in zip(ocr_polys, ocr_polys_bad):
504623
filtered_table_polys = []
505-
for p, is_blank in zip(table_polys, table_polys_blank):
506-
if is_blank:
624+
for p, is_bad in zip(table_polys, table_polys_bad):
625+
if is_bad:
507626
continue
508627
polygon = p.polygon
509628
# Round the polygon
@@ -527,19 +646,21 @@ def get_ocr_results(self, table_images: List[Image.Image], ocr_polys: List[List[
527646
)
528647

529648
# Re-align the predictions to the original length, since we skipped some predictions
530-
for table_ocr_result, table_polys_blank in zip(ocr_results, ocr_polys_blank):
649+
for table_ocr_result, table_polys_bad in zip(ocr_results, ocr_polys_bad):
531650
updated_lines = []
532651
idx = 0
533-
for is_blank in table_polys_blank:
534-
if is_blank:
535-
updated_lines.append(TextLine(
536-
text = "",
537-
polygon=[[0, 0], [0, 0], [0, 0], [0, 0]],
538-
confidence=1,
539-
chars=[],
540-
original_text_good=False,
541-
words=None
542-
))
652+
for is_bad in table_polys_bad:
653+
if is_bad:
654+
updated_lines.append(
655+
TextLine(
656+
text="",
657+
polygon=[[0, 0], [0, 0], [0, 0], [0, 0]],
658+
confidence=1,
659+
chars=[],
660+
original_text_good=False,
661+
words=None,
662+
)
663+
)
543664
else:
544665
updated_lines.append(table_ocr_result.text_lines[idx])
545666
idx += 1
@@ -548,7 +669,7 @@ def get_ocr_results(self, table_images: List[Image.Image], ocr_polys: List[List[
548669
return ocr_results
549670

550671
def assign_ocr_lines(self, tables: List[TableResult], table_blocks: list):
551-
ocr_tables, ocr_polys, ocr_idxs = self.needs_ocr(tables)
672+
ocr_tables, ocr_polys, ocr_idxs = self.needs_ocr(tables, table_blocks)
552673
det_images = [
553674
t["table_image"] for i, t in enumerate(table_blocks) if i in ocr_idxs
554675
]
@@ -589,3 +710,10 @@ def get_recognition_batch_size(self):
589710
elif settings.TORCH_DEVICE_MODEL == "cuda":
590711
return 48
591712
return 32
713+
714+
def get_detection_batch_size(self):
    """
    Return the batch size to use for the text detection model.

    An explicitly configured ``self.detection_batch_size`` always wins;
    otherwise the value depends on ``settings.TORCH_DEVICE_MODEL``:
    10 on ``"cuda"``, 4 on anything else.
    """
    configured = self.detection_batch_size
    if configured is not None:
        return configured
    return 10 if settings.TORCH_DEVICE_MODEL == "cuda" else 4

0 commit comments

Comments
 (0)