55import copy
66import logging
77import re
8- import sys
8+ from dataclasses import dataclass , field
99from typing import Dict , List , Set , Tuple
1010
1111from docling_core .types .doc .base import BoundingBox , Size
1212from docling_core .types .doc .document import RefItem
1313from docling_core .types .doc .labels import DocItemLabel
14- from pydantic import BaseModel
1514from rtree import index as rtree_index
1615
1716_log = logging .getLogger (__name__ )
@@ -48,6 +47,21 @@ def follows_maintext_order(self, rhs) -> bool:
4847 return self .cid + 1 == rhs .cid
4948
5049
50+ @dataclass
51+ class _ReadingOrderPredictorState :
52+ """
53+ State container of the reading order of a single page
54+ """
55+
56+ h2i_map : Dict [int , int ] = field (default_factory = dict )
57+ i2h_map : Dict [int , int ] = field (default_factory = dict )
58+ l2r_map : Dict [int , int ] = field (default_factory = dict )
59+ r2l_map : Dict [int , int ] = field (default_factory = dict )
60+ up_map : Dict [int , List [int ]] = field (default_factory = dict )
61+ dn_map : Dict [int , List [int ]] = field (default_factory = dict )
62+ heads : List [int ] = field (default_factory = list )
63+
64+
5165class ReadingOrderPredictor :
5266 r"""
5367 Rule based reading order for DoclingDocument
@@ -59,20 +73,6 @@ def __init__(self):
5973 # Apply horizontal dilation only if it is less than this page-width normalized threshold
6074 self ._horizontal_dilation_threshold_norm = 0.15
6175
62- self .initialise ()
63-
64- def initialise (self ):
65- self .h2i_map : Dict [int , int ] = {}
66- self .i2h_map : Dict [int , int ] = {}
67-
68- self .l2r_map : Dict [int , int ] = {}
69- self .r2l_map : Dict [int , int ] = {}
70-
71- self .up_map : Dict [int , List [int ]] = {}
72- self .dn_map : Dict [int , List [int ]] = {}
73-
74- self .heads : List [int ] = []
75-
7676 def predict_reading_order (
7777 self , page_elements : List [PageElement ]
7878 ) -> List [PageElement ]:
@@ -217,10 +217,10 @@ def predict_merges(
217217
218218 def _predict_page (self , page_elements : List [PageElement ]) -> List [PageElement ]:
219219 r"""
220- Reorder the output of the
220+ Reorder the output of the page elements into a single-page reading order.
221221 """
222222
223- self . initialise ()
223+ state = _ReadingOrderPredictorState ()
224224
225225 """
226226 for i, elem in enumerate(page_elements):
@@ -231,50 +231,49 @@ def _predict_page(self, page_elements: List[PageElement]) -> List[PageElement]:
231231 page_elements [i ] = elem .to_bottom_left_origin ( # type: ignore
232232 page_height = page_elements [i ].page_size .height
233233 )
234+ self ._init_h2i_map (page_elements , state )
234235
235- self ._init_h2i_map (page_elements )
236+ self ._init_l2r_map (page_elements , state )
236237
237- self ._init_l2r_map (page_elements )
238-
239- self ._init_ud_maps (page_elements )
238+ self ._init_ud_maps (page_elements , state )
240239
241240 if self .dilated_page_element :
242241 dilated_page_elements : List [PageElement ] = copy .deepcopy (
243242 page_elements
244243 ) # deep-copy
245244
246245 dilated_page_elements = self ._do_horizontal_dilation (
247- page_elements , dilated_page_elements
246+ page_elements , dilated_page_elements , state
248247 )
249248
250249 # redo with dilated provs
251- self ._init_ud_maps (dilated_page_elements )
250+ self ._init_ud_maps (dilated_page_elements , state )
252251
253- self ._find_heads (page_elements )
252+ self ._find_heads (page_elements , state )
254253
255- self ._sort_ud_maps (page_elements )
254+ self ._sort_ud_maps (page_elements , state )
256255
257256 """
258- print(f"heads: {self .heads}")
257+ print(f"heads: {state .heads}")
259258
260259 print("l2r: ")
261- for k,v in self .l2r_map.items():
260+ for k,v in state .l2r_map.items():
262261 print(f" -> {k}: {v}")
263262
264263 print("r2l: ")
265- for k,v in self .r2l_map.items():
264+ for k,v in state .r2l_map.items():
266265 print(f" -> {k}: {v}")
267266
268267 print("up: ")
269- for k,v in self .up_map.items():
268+ for k,v in state .up_map.items():
270269 print(f" -> {k}: {v}")
271270
272271 print("dn: ")
273- for k,v in self .dn_map.items():
272+ for k,v in state .dn_map.items():
274273 print(f" -> {k}: {v}")
275274 """
276275
277- order : List [int ] = self ._find_order (page_elements )
276+ order : List [int ] = self ._find_order (page_elements , state )
278277 # print(f"order: {order}")
279278
280279 sorted_elements : List [PageElement ] = []
@@ -288,17 +287,21 @@ def _predict_page(self, page_elements: List[PageElement]) -> List[PageElement]:
288287
289288 return sorted_elements
290289
291- def _init_h2i_map (self , page_elems : List [PageElement ]):
292- self .h2i_map = {}
293- self .i2h_map = {}
290+ def _init_h2i_map (
291+ self , page_elems : List [PageElement ], state : _ReadingOrderPredictorState
292+ ) -> None :
293+ state .h2i_map = {}
294+ state .i2h_map = {}
294295
295296 for i , pelem in enumerate (page_elems ):
296- self .h2i_map [pelem .cid ] = i
297- self .i2h_map [i ] = pelem .cid
297+ state .h2i_map [pelem .cid ] = i
298+ state .i2h_map [i ] = pelem .cid
298299
299- def _init_l2r_map (self , page_elems : List [PageElement ]):
300- self .l2r_map = {}
301- self .r2l_map = {}
300+ def _init_l2r_map (
301+ self , page_elems : List [PageElement ], state : _ReadingOrderPredictorState
302+ ) -> None :
303+ state .l2r_map = {}
304+ state .r2l_map = {}
302305
303306 # this currently leads to errors ... might be necessary in the future ...
304307 for i , pelem_i in enumerate (page_elems ):
@@ -309,33 +312,35 @@ def _init_l2r_map(self, page_elems: List[PageElement]):
309312 and pelem_i .is_strictly_left_of (pelem_j )
310313 and pelem_i .overlaps_vertically_with_iou (pelem_j , 0.8 )
311314 ):
312- self .l2r_map [i ] = j
313- self .r2l_map [j ] = i
315+ state .l2r_map [i ] = j
316+ state .r2l_map [j ] = i
314317
315- def _init_ud_maps (self , page_elems : List [PageElement ]) -> None :
318+ def _init_ud_maps (
319+ self , page_elems : List [PageElement ], state : _ReadingOrderPredictorState
320+ ) -> None :
316321 """
317322 Initialize up/down maps for reading order prediction using R-tree spatial indexing.
318323
319324 Uses R-tree for spatial queries.
320325 Determines linear reading sequence by finding preceding/following elements.
321326 """
322- self .up_map = {}
323- self .dn_map = {}
327+ state .up_map = {}
328+ state .dn_map = {}
324329
325330 for i , pelem_i in enumerate (page_elems ):
326- self .up_map [i ] = []
327- self .dn_map [i ] = []
331+ state .up_map [i ] = []
332+ state .dn_map [i ] = []
328333
329334 # Build R-tree spatial index
330335 spatial_idx = rtree_index .Index ()
331336 for i , pelem in enumerate (page_elems ):
332337 spatial_idx .insert (i , (pelem .l , pelem .b , pelem .r , pelem .t ))
333338
334339 for j , pelem_j in enumerate (page_elems ):
335- if j in self .r2l_map :
336- i = self .r2l_map [j ]
337- self .dn_map [i ] = [j ]
338- self .up_map [j ] = [i ]
340+ if j in state .r2l_map :
341+ i = state .r2l_map [j ]
342+ state .dn_map [i ] = [j ]
343+ state .up_map [j ] = [i ]
339344 continue
340345
341346 # Find elements above current that might precede it in reading order
@@ -360,11 +365,11 @@ def _init_ud_maps(self, page_elems: List[PageElement]) -> None:
360365 spatial_idx , page_elems , i , j , pelem_i , pelem_j
361366 ):
362367 # Follow left-to-right mapping
363- while i in self .l2r_map :
364- i = self .l2r_map [i ]
368+ while i in state .l2r_map :
369+ i = state .l2r_map [i ]
365370
366- self .dn_map [i ].append (j )
367- self .up_map [j ].append (i )
371+ state .dn_map [i ].append (j )
372+ state .up_map [j ].append (i )
368373
369374 def _has_sequence_interruption (
370375 self ,
@@ -403,7 +408,12 @@ def _has_sequence_interruption(
403408
404409 return False
405410
406- def _do_horizontal_dilation (self , page_elems , dilated_page_elems ):
411+ def _do_horizontal_dilation (
412+ self ,
413+ page_elems : List [PageElement ],
414+ dilated_page_elems : List [PageElement ],
415+ state : _ReadingOrderPredictorState ,
416+ ) -> List [PageElement ]:
407417 # Compute the dilation threshold
408418 th = 0.0
409419 if page_elems :
@@ -418,8 +428,8 @@ def _do_horizontal_dilation(self, page_elems, dilated_page_elems):
418428 x1 = pelem_i .r
419429 y1 = pelem_i .t
420430
421- if i in self .up_map and len (self .up_map [i ]) > 0 :
422- pelem_up = page_elems [self .up_map [i ][0 ]]
431+ if i in state .up_map and len (state .up_map [i ]) > 0 :
432+ pelem_up = page_elems [state .up_map [i ][0 ]]
423433
424434 # Apply threshold for horizontal dilation
425435 x0_dil = min (x0 , pelem_up .l )
@@ -429,8 +439,8 @@ def _do_horizontal_dilation(self, page_elems, dilated_page_elems):
429439 x0 = x0_dil
430440 x1 = x1_dil
431441
432- if i in self .dn_map and len (self .dn_map [i ]) > 0 :
433- pelem_dn = page_elems [self .dn_map [i ][0 ]]
442+ if i in state .dn_map and len (state .dn_map [i ]) > 0 :
443+ pelem_dn = page_elems [state .dn_map [i ][0 ]]
434444
435445 # Apply threshold for horizontal dilation
436446 x0_dil = min (x0 , pelem_dn .l )
@@ -461,9 +471,11 @@ def _do_horizontal_dilation(self, page_elems, dilated_page_elems):
461471
462472 return dilated_page_elems
463473
464- def _find_heads (self , page_elems : List [PageElement ]):
474+ def _find_heads (
475+ self , page_elems : List [PageElement ], state : _ReadingOrderPredictorState
476+ ) -> None :
465477 head_page_elems = []
466- for key , vals in self .up_map .items ():
478+ for key , vals in state .up_map .items ():
467479 if len (vals ) == 0 :
468480 head_page_elems .append (page_elems [key ])
469481
@@ -482,12 +494,14 @@ def _find_heads(self, page_elems: List[PageElement]):
482494 print(f"{l}\t {str(elem)}")
483495 """
484496
485- self .heads = []
497+ state .heads = []
486498 for item in head_page_elems :
487- self .heads .append (self .h2i_map [item .cid ])
499+ state .heads .append (state .h2i_map [item .cid ])
488500
489- def _sort_ud_maps (self , provs : List [PageElement ]):
490- for ind_i , vals in self .dn_map .items ():
501+ def _sort_ud_maps (
502+ self , provs : List [PageElement ], state : _ReadingOrderPredictorState
503+ ) -> None :
504+ for ind_i , vals in state .dn_map .items ():
491505
492506 child_provs : List [PageElement ] = []
493507 for ind_j in vals :
@@ -496,33 +510,37 @@ def _sort_ud_maps(self, provs: List[PageElement]):
496510 # this will invoke __lt__ from PageElements
497511 child_provs = sorted (child_provs )
498512
499- self .dn_map [ind_i ] = []
513+ state .dn_map [ind_i ] = []
500514 for child in child_provs :
501- self .dn_map [ind_i ].append (self .h2i_map [child .cid ])
515+ state .dn_map [ind_i ].append (state .h2i_map [child .cid ])
502516
503- def _find_order (self , provs : List [PageElement ]):
517+ def _find_order (
518+ self , provs : List [PageElement ], state : _ReadingOrderPredictorState
519+ ) -> List [int ]:
504520 order : List [int ] = []
505521
506522 visited : List [bool ] = [False for _ in provs ]
507523
508- for j in self .heads :
524+ for j in state .heads :
509525
510526 if not visited [j ]:
511527
512528 order .append (j )
513529 visited [j ] = True
514- self ._depth_first_search_downwards (j , order , visited )
530+ self ._depth_first_search_downwards (j , order , visited , state )
515531
516532 if len (order ) != len (provs ):
517533 _log .error ("something went wrong" )
518534
519535 return order
520536
521- def _depth_first_search_upwards (self , j : int , visited : List [bool ]):
537+ def _depth_first_search_upwards (
538+ self , j : int , visited : List [bool ], state : _ReadingOrderPredictorState
539+ ) -> int :
522540 """depth_first_search_upwards without recursion"""
523541 k = j
524542 while True :
525- inds : List [int ] = self .up_map [k ]
543+ inds : List [int ] = state .up_map [k ]
526544 found_not_visited = False
527545 for ind in inds :
528546 if not visited [ind ]:
@@ -535,26 +553,30 @@ def _depth_first_search_upwards(self, j: int, visited: List[bool]):
535553 return k
536554
537555 def _depth_first_search_downwards (
538- self , j : int , order : List [int ], visited : List [bool ]
539- ):
556+ self ,
557+ j : int ,
558+ order : List [int ],
559+ visited : List [bool ],
560+ state : _ReadingOrderPredictorState ,
561+ ) -> None :
540562 """depth_first_search_downwards without recursion"""
541563 # The outermost list is the main stack.
542564 # Each list element is a tuple containing the list of the indices to be checked and an offset
543- stack : List [Tuple [List [int ], int ]] = [(self .dn_map [j ], 0 )]
565+ stack : List [Tuple [List [int ], int ]] = [(state .dn_map [j ], 0 )]
544566
545567 while stack :
546568 inds , offset = stack [- 1 ]
547569
548570 found_non_visited = False
549571 if offset < len (inds ):
550572 for new_offset , i in enumerate (inds [offset :]):
551- k : int = self ._depth_first_search_upwards (i , visited )
573+ k : int = self ._depth_first_search_upwards (i , visited , state )
552574
553575 if not visited [k ]:
554576 order .append (k )
555577 visited [k ] = True
556578 stack [- 1 ] = (inds , new_offset + 1 )
557- stack .append ((self .dn_map [k ], 0 ))
579+ stack .append ((state .dn_map [k ], 0 ))
558580 found_non_visited = True
559581 break
560582
@@ -662,7 +684,9 @@ def _find_to_captions(
662684 print("to-captions: ", cid_i, ": ", to_item)
663685 """
664686
665- def _remove_overlapping_indexes (mapping ):
687+ def _remove_overlapping_indexes (
688+ mapping : Dict [int , List [int ]]
689+ ) -> Dict [int , List [int ]]:
666690 used = set ()
667691 result = {}
668692 for key , values in sorted (mapping .items ()):
0 commit comments