Commit 78473fe

fix: ReadingOrderPredictor: Make _predict_page threadsafe (#137)
Signed-off-by: lmchr <[email protected]>
1 parent c7c4e0b commit 78473fe

File tree: 1 file changed (+103, -79 lines)


docling_ibm_models/reading_order/reading_order_rb.py

Lines changed: 103 additions & 79 deletions
@@ -5,13 +5,12 @@
 import copy
 import logging
 import re
-import sys
+from dataclasses import dataclass, field
 from typing import Dict, List, Set, Tuple
 
 from docling_core.types.doc.base import BoundingBox, Size
 from docling_core.types.doc.document import RefItem
 from docling_core.types.doc.labels import DocItemLabel
-from pydantic import BaseModel
 from rtree import index as rtree_index
 
 _log = logging.getLogger(__name__)
@@ -48,6 +47,21 @@ def follows_maintext_order(self, rhs) -> bool:
         return self.cid + 1 == rhs.cid
 
 
+@dataclass
+class _ReadingOrderPredictorState:
+    """
+    State container of the reading order of a single page
+    """
+
+    h2i_map: Dict[int, int] = field(default_factory=dict)
+    i2h_map: Dict[int, int] = field(default_factory=dict)
+    l2r_map: Dict[int, int] = field(default_factory=dict)
+    r2l_map: Dict[int, int] = field(default_factory=dict)
+    up_map: Dict[int, List[int]] = field(default_factory=dict)
+    dn_map: Dict[int, List[int]] = field(default_factory=dict)
+    heads: List[int] = field(default_factory=list)
+
+
 class ReadingOrderPredictor:
     r"""
     Rule based reading order for DoclingDocument
@@ -59,20 +73,6 @@ def __init__(self):
         # Apply horizontal dilation only if it is less than this page-width normalized threshold
         self._horizontal_dilation_threshold_norm = 0.15
 
-        self.initialise()
-
-    def initialise(self):
-        self.h2i_map: Dict[int, int] = {}
-        self.i2h_map: Dict[int, int] = {}
-
-        self.l2r_map: Dict[int, int] = {}
-        self.r2l_map: Dict[int, int] = {}
-
-        self.up_map: Dict[int, List[int]] = {}
-        self.dn_map: Dict[int, List[int]] = {}
-
-        self.heads: List[int] = []
-
     def predict_reading_order(
         self, page_elements: List[PageElement]
     ) -> List[PageElement]:
@@ -217,10 +217,10 @@ def predict_merges(
 
     def _predict_page(self, page_elements: List[PageElement]) -> List[PageElement]:
         r"""
-        Reorder the output of the
+        Reorder the output of the page elements into a single-page reading order.
         """
 
-        self.initialise()
+        state = _ReadingOrderPredictorState()
 
         """
         for i, elem in enumerate(page_elements):
@@ -231,50 +231,49 @@ def _predict_page(self, page_elements: List[PageElement]) -> List[PageElement]:
             page_elements[i] = elem.to_bottom_left_origin(  # type: ignore
                 page_height=page_elements[i].page_size.height
             )
+        self._init_h2i_map(page_elements, state)
 
-        self._init_h2i_map(page_elements)
+        self._init_l2r_map(page_elements, state)
 
-        self._init_l2r_map(page_elements)
-
-        self._init_ud_maps(page_elements)
+        self._init_ud_maps(page_elements, state)
 
         if self.dilated_page_element:
             dilated_page_elements: List[PageElement] = copy.deepcopy(
                 page_elements
             )  # deep-copy
 
             dilated_page_elements = self._do_horizontal_dilation(
-                page_elements, dilated_page_elements
+                page_elements, dilated_page_elements, state
             )
 
             # redo with dilated provs
-            self._init_ud_maps(dilated_page_elements)
+            self._init_ud_maps(dilated_page_elements, state)
 
-        self._find_heads(page_elements)
+        self._find_heads(page_elements, state)
 
-        self._sort_ud_maps(page_elements)
+        self._sort_ud_maps(page_elements, state)
 
         """
-        print(f"heads: {self.heads}")
+        print(f"heads: {state.heads}")
 
         print("l2r: ")
-        for k,v in self.l2r_map.items():
+        for k,v in state.l2r_map.items():
             print(f" -> {k}: {v}")
 
         print("r2l: ")
-        for k,v in self.r2l_map.items():
+        for k,v in state.r2l_map.items():
            print(f" -> {k}: {v}")
 
         print("up: ")
-        for k,v in self.up_map.items():
+        for k,v in state.up_map.items():
            print(f" -> {k}: {v}")
 
         print("dn: ")
-        for k,v in self.dn_map.items():
+        for k,v in state.dn_map.items():
            print(f" -> {k}: {v}")
         """
 
-        order: List[int] = self._find_order(page_elements)
+        order: List[int] = self._find_order(page_elements, state)
         # print(f"order: {order}")
 
         sorted_elements: List[PageElement] = []
@@ -288,17 +287,21 @@ def _predict_page(self, page_elements: List[PageElement]) -> List[PageElement]:
 
         return sorted_elements
 
-    def _init_h2i_map(self, page_elems: List[PageElement]):
-        self.h2i_map = {}
-        self.i2h_map = {}
+    def _init_h2i_map(
+        self, page_elems: List[PageElement], state: _ReadingOrderPredictorState
+    ) -> None:
+        state.h2i_map = {}
+        state.i2h_map = {}
 
         for i, pelem in enumerate(page_elems):
-            self.h2i_map[pelem.cid] = i
-            self.i2h_map[i] = pelem.cid
+            state.h2i_map[pelem.cid] = i
+            state.i2h_map[i] = pelem.cid
 
-    def _init_l2r_map(self, page_elems: List[PageElement]):
-        self.l2r_map = {}
-        self.r2l_map = {}
+    def _init_l2r_map(
+        self, page_elems: List[PageElement], state: _ReadingOrderPredictorState
+    ) -> None:
+        state.l2r_map = {}
+        state.r2l_map = {}
 
         # this currently leads to errors ... might be necessary in the future ...
         for i, pelem_i in enumerate(page_elems):
@@ -309,33 +312,35 @@ def _init_l2r_map(self, page_elems: List[PageElement]):
                 and pelem_i.is_strictly_left_of(pelem_j)
                 and pelem_i.overlaps_vertically_with_iou(pelem_j, 0.8)
             ):
-                self.l2r_map[i] = j
-                self.r2l_map[j] = i
+                state.l2r_map[i] = j
+                state.r2l_map[j] = i
 
-    def _init_ud_maps(self, page_elems: List[PageElement]) -> None:
+    def _init_ud_maps(
+        self, page_elems: List[PageElement], state: _ReadingOrderPredictorState
+    ) -> None:
         """
         Initialize up/down maps for reading order prediction using R-tree spatial indexing.
 
         Uses R-tree for spatial queries.
         Determines linear reading sequence by finding preceding/following elements.
         """
-        self.up_map = {}
-        self.dn_map = {}
+        state.up_map = {}
+        state.dn_map = {}
 
         for i, pelem_i in enumerate(page_elems):
-            self.up_map[i] = []
-            self.dn_map[i] = []
+            state.up_map[i] = []
+            state.dn_map[i] = []
 
         # Build R-tree spatial index
         spatial_idx = rtree_index.Index()
         for i, pelem in enumerate(page_elems):
             spatial_idx.insert(i, (pelem.l, pelem.b, pelem.r, pelem.t))
 
         for j, pelem_j in enumerate(page_elems):
-            if j in self.r2l_map:
-                i = self.r2l_map[j]
-                self.dn_map[i] = [j]
-                self.up_map[j] = [i]
+            if j in state.r2l_map:
+                i = state.r2l_map[j]
+                state.dn_map[i] = [j]
+                state.up_map[j] = [i]
                 continue
 
             # Find elements above current that might precede it in reading order
@@ -360,11 +365,11 @@ def _init_ud_maps(self, page_elems: List[PageElement]) -> None:
                 spatial_idx, page_elems, i, j, pelem_i, pelem_j
             ):
                 # Follow left-to-right mapping
-                while i in self.l2r_map:
-                    i = self.l2r_map[i]
+                while i in state.l2r_map:
+                    i = state.l2r_map[i]
 
-                self.dn_map[i].append(j)
-                self.up_map[j].append(i)
+                state.dn_map[i].append(j)
+                state.up_map[j].append(i)
 
     def _has_sequence_interruption(
         self,
@@ -403,7 +408,12 @@ def _has_sequence_interruption(
 
         return False
 
-    def _do_horizontal_dilation(self, page_elems, dilated_page_elems):
+    def _do_horizontal_dilation(
+        self,
+        page_elems: List[PageElement],
+        dilated_page_elems: List[PageElement],
+        state: _ReadingOrderPredictorState,
+    ) -> List[PageElement]:
         # Compute the dilation threshold
         th = 0.0
         if page_elems:
@@ -418,8 +428,8 @@ def _do_horizontal_dilation(self, page_elems, dilated_page_elems):
             x1 = pelem_i.r
             y1 = pelem_i.t
 
-            if i in self.up_map and len(self.up_map[i]) > 0:
-                pelem_up = page_elems[self.up_map[i][0]]
+            if i in state.up_map and len(state.up_map[i]) > 0:
+                pelem_up = page_elems[state.up_map[i][0]]
 
                 # Apply threshold for horizontal dilation
                 x0_dil = min(x0, pelem_up.l)
@@ -429,8 +439,8 @@ def _do_horizontal_dilation(self, page_elems, dilated_page_elems):
                 x0 = x0_dil
                 x1 = x1_dil
 
-            if i in self.dn_map and len(self.dn_map[i]) > 0:
-                pelem_dn = page_elems[self.dn_map[i][0]]
+            if i in state.dn_map and len(state.dn_map[i]) > 0:
+                pelem_dn = page_elems[state.dn_map[i][0]]
 
                 # Apply threshold for horizontal dilation
                 x0_dil = min(x0, pelem_dn.l)
@@ -461,9 +471,11 @@ def _do_horizontal_dilation(self, page_elems, dilated_page_elems):
 
         return dilated_page_elems
 
-    def _find_heads(self, page_elems: List[PageElement]):
+    def _find_heads(
+        self, page_elems: List[PageElement], state: _ReadingOrderPredictorState
+    ) -> None:
         head_page_elems = []
-        for key, vals in self.up_map.items():
+        for key, vals in state.up_map.items():
            if len(vals) == 0:
                 head_page_elems.append(page_elems[key])
 
@@ -482,12 +494,14 @@ def _find_heads(self, page_elems: List[PageElement]):
             print(f"{l}\t{str(elem)}")
         """
 
-        self.heads = []
+        state.heads = []
         for item in head_page_elems:
-            self.heads.append(self.h2i_map[item.cid])
+            state.heads.append(state.h2i_map[item.cid])
 
-    def _sort_ud_maps(self, provs: List[PageElement]):
-        for ind_i, vals in self.dn_map.items():
+    def _sort_ud_maps(
+        self, provs: List[PageElement], state: _ReadingOrderPredictorState
+    ) -> None:
+        for ind_i, vals in state.dn_map.items():
 
             child_provs: List[PageElement] = []
             for ind_j in vals:
@@ -496,33 +510,37 @@ def _sort_ud_maps(self, provs: List[PageElement]):
             # this will invoke __lt__ from PageElements
             child_provs = sorted(child_provs)
 
-            self.dn_map[ind_i] = []
+            state.dn_map[ind_i] = []
             for child in child_provs:
-                self.dn_map[ind_i].append(self.h2i_map[child.cid])
+                state.dn_map[ind_i].append(state.h2i_map[child.cid])
 
-    def _find_order(self, provs: List[PageElement]):
+    def _find_order(
+        self, provs: List[PageElement], state: _ReadingOrderPredictorState
+    ) -> List[int]:
         order: List[int] = []
 
         visited: List[bool] = [False for _ in provs]
 
-        for j in self.heads:
+        for j in state.heads:
 
             if not visited[j]:
 
                 order.append(j)
                 visited[j] = True
-                self._depth_first_search_downwards(j, order, visited)
+                self._depth_first_search_downwards(j, order, visited, state)
 
         if len(order) != len(provs):
             _log.error("something went wrong")
 
         return order
 
-    def _depth_first_search_upwards(self, j: int, visited: List[bool]):
+    def _depth_first_search_upwards(
+        self, j: int, visited: List[bool], state: _ReadingOrderPredictorState
+    ) -> int:
         """depth_first_search_upwards without recursion"""
         k = j
         while True:
-            inds: List[int] = self.up_map[k]
+            inds: List[int] = state.up_map[k]
             found_not_visited = False
             for ind in inds:
                 if not visited[ind]:
@@ -535,26 +553,30 @@ def _depth_first_search_upwards(self, j: int, visited: List[bool]):
         return k
 
     def _depth_first_search_downwards(
-        self, j: int, order: List[int], visited: List[bool]
-    ):
+        self,
+        j: int,
+        order: List[int],
+        visited: List[bool],
+        state: _ReadingOrderPredictorState,
+    ) -> None:
         """depth_first_search_downwards without recursion"""
         # The outermost list is the main stack.
         # Each list element is a tuple containint the list of the indices to be checked and an offset
-        stack: List[Tuple[List[int], int]] = [(self.dn_map[j], 0)]
+        stack: List[Tuple[List[int], int]] = [(state.dn_map[j], 0)]
 
         while stack:
             inds, offset = stack[-1]
 
             found_non_visited = False
             if offset < len(inds):
                 for new_offset, i in enumerate(inds[offset:]):
-                    k: int = self._depth_first_search_upwards(i, visited)
+                    k: int = self._depth_first_search_upwards(i, visited, state)
 
                     if not visited[k]:
                         order.append(k)
                         visited[k] = True
                         stack[-1] = (inds, new_offset + 1)
-                        stack.append((self.dn_map[k], 0))
+                        stack.append((state.dn_map[k], 0))
                         found_non_visited = True
                         break
 
@@ -662,7 +684,9 @@ def _find_to_captions(
        print("to-captions: ", cid_i, ": ", to_item)
        """
 
-        def _remove_overlapping_indexes(mapping):
+        def _remove_overlapping_indexes(
+            mapping: Dict[int, List[int]]
+        ) -> Dict[int, List[int]]:
            used = set()
            result = {}
            for key, values in sorted(mapping.items()):