Skip to content

Commit 0758ad1

Browse files
authored
feat: Performance optimizations for reading order and table model (#115)
Signed-off-by: Christoph Auer <[email protected]>
1 parent 3e495b9 commit 0758ad1

File tree

4 files changed

+90
-29
lines changed

4 files changed

+90
-29
lines changed

docling_ibm_models/reading_order/reading_order_rb.py

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
#
55
import copy
66
import logging
7-
import os
87
import re
9-
from collections.abc import Iterable
108
from typing import Dict, List, Set, Tuple
119

1210
from docling_core.types.doc.base import BoundingBox, Size
1311
from docling_core.types.doc.document import RefItem
1412
from docling_core.types.doc.labels import DocItemLabel
1513
from pydantic import BaseModel
14+
from rtree import index as rtree_index
1615

1716

1817
class PageElement(BoundingBox):
@@ -306,59 +305,97 @@ def _init_l2r_map(self, page_elems: List[PageElement]):
306305
self.l2r_map[i] = j
307306
self.r2l_map[j] = i
308307

309-
def _init_ud_maps(self, page_elems: List[PageElement]):
308+
def _init_ud_maps(self, page_elems: List[PageElement]) -> None:
309+
"""
310+
Initialize up/down maps for reading order prediction using R-tree spatial indexing.
311+
312+
Uses R-tree for spatial queries.
313+
Determines linear reading sequence by finding preceding/following elements.
314+
"""
310315
self.up_map = {}
311316
self.dn_map = {}
312317

313318
for i, pelem_i in enumerate(page_elems):
314319
self.up_map[i] = []
315320
self.dn_map[i] = []
316321

317-
for j, pelem_j in enumerate(page_elems):
322+
# Build R-tree spatial index
323+
spatial_idx = rtree_index.Index()
324+
for i, pelem in enumerate(page_elems):
325+
spatial_idx.insert(i, (pelem.l, pelem.b, pelem.r, pelem.t))
318326

327+
for j, pelem_j in enumerate(page_elems):
319328
if j in self.r2l_map:
320329
i = self.r2l_map[j]
321-
322330
self.dn_map[i] = [j]
323331
self.up_map[j] = [i]
324-
325332
continue
326333

327-
for i, pelem_i in enumerate(page_elems):
334+
# Find elements above current that might precede it in reading order
335+
query_bbox = (pelem_j.l - 0.1, pelem_j.t, pelem_j.r + 0.1, float("inf"))
336+
candidates = list(spatial_idx.intersection(query_bbox))
328337

338+
for i in candidates:
329339
if i == j:
330340
continue
331341

332-
is_horizontally_connected: bool = False
333-
is_i_just_above_j: bool = pelem_i.overlaps_horizontally(
334-
pelem_j
335-
) and pelem_i.is_strictly_above(pelem_j)
336-
337-
for w, pelem_w in enumerate(page_elems):
338-
339-
if not is_horizontally_connected:
340-
is_horizontally_connected = pelem_w.is_horizontally_connected(
341-
pelem_i, pelem_j
342-
)
342+
pelem_i = page_elems[i]
343343

344-
# ensure there is no other element that is between i and j vertically
345-
if is_i_just_above_j and (
346-
pelem_i.overlaps_horizontally(pelem_w)
347-
or pelem_j.overlaps_horizontally(pelem_w)
348-
):
349-
i_above_w: bool = pelem_i.is_strictly_above(pelem_w)
350-
w_above_j: bool = pelem_w.is_strictly_above(pelem_j)
351-
352-
is_i_just_above_j = not (i_above_w and w_above_j)
353-
354-
if is_i_just_above_j:
344+
# Check spatial relationship
345+
if not (
346+
pelem_i.is_strictly_above(pelem_j)
347+
and pelem_i.overlaps_horizontally(pelem_j)
348+
):
349+
continue
355350

351+
# Check for interrupting elements
352+
if not self._has_sequence_interruption(
353+
spatial_idx, page_elems, i, j, pelem_i, pelem_j
354+
):
355+
# Follow left-to-right mapping
356356
while i in self.l2r_map:
357357
i = self.l2r_map[i]
358358

359359
self.dn_map[i].append(j)
360360
self.up_map[j].append(i)
361361

362+
def _has_sequence_interruption(
363+
self,
364+
spatial_idx: rtree_index.Index,
365+
page_elems: List[PageElement],
366+
i: int,
367+
j: int,
368+
pelem_i: PageElement,
369+
pelem_j: PageElement,
370+
) -> bool:
371+
"""Check if elements interrupt the reading sequence between i and j."""
372+
# Query R-tree for elements between i and j
373+
x_min = min(pelem_i.l, pelem_j.l) - 1.0
374+
x_max = max(pelem_i.r, pelem_j.r) + 1.0
375+
y_min = pelem_j.t
376+
y_max = pelem_i.b
377+
378+
candidates = list(spatial_idx.intersection((x_min, y_min, x_max, y_max)))
379+
380+
for w in candidates:
381+
if w in (i, j):
382+
continue
383+
384+
pelem_w = page_elems[w]
385+
386+
# Check if w interrupts the i->j sequence
387+
if (
388+
(
389+
pelem_i.overlaps_horizontally(pelem_w)
390+
or pelem_j.overlaps_horizontally(pelem_w)
391+
)
392+
and pelem_i.is_strictly_above(pelem_w)
393+
and pelem_w.is_strictly_above(pelem_j)
394+
):
395+
return True
396+
397+
return False
398+
362399
def _do_horizontal_dilation(self, page_elems, dilated_page_elems):
363400

364401
for i, pelem_i in enumerate(dilated_page_elems):

docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Copyright IBM Corp. 2024 - 2024
33
# SPDX-License-Identifier: MIT
44
#
5+
6+
57
import logging
68
import math
79
from typing import Optional
@@ -99,6 +101,7 @@ def forward( # type: ignore
99101
tgt,
100102
attn_mask=None, # None, because we only care about the last tag
101103
key_padding_mask=tgt_key_padding_mask,
104+
need_weights=False, # Optimization: Don't compute attention weights
102105
)[0]
103106
tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
104107
tgt_last_tok = self.norm1(tgt_last_tok)
@@ -110,6 +113,7 @@ def forward( # type: ignore
110113
memory,
111114
attn_mask=memory_mask,
112115
key_padding_mask=memory_key_padding_mask,
116+
need_weights=False, # Optimization: Don't compute attention weights
113117
)[0]
114118
tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
115119
tgt_last_tok = self.norm2(tgt_last_tok)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ dependencies = [
4242
'docling-core (>=2.19.0,<3.0.0)',
4343
'transformers (>=4.42.0,<5.0.0)',
4444
'numpy (>=1.24.4,<3.0.0)',
45+
"rtree>=1.0.0",
4546
]
4647

4748
[project.urls]

uv.lock

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)